# Python Symbolic Information Theoretic Inequality Prover
# Copyright (C) 2020 Cheuk Ting Li
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
Python Symbolic Information Theoretic Inequality Prover
Version 1.1.7
Copyright (C) 2020 Cheuk Ting Li
The working principle of PSITIP (existential information inequalities) is described in the following article:
C. T. Li, "An Automated Theorem Proving Framework for Information-Theoretic Results,"
in IEEE Transactions on Information Theory, vol. 69, no. 11, pp. 6857-6877, Nov. 2023.
Link: https://ieeexplore.ieee.org/document/10185937
Preprint: https://arxiv.org/pdf/2101.12370.pdf
Based on the general method of using linear programming for proving information
theoretic inequalities described in the following work:
R. W. Yeung, "A new outlook on Shannon's information measures,"
IEEE Trans. Inform. Theory, vol. 37, pp. 466-474, May 1991.
R. W. Yeung, "A framework for linear information inequalities,"
IEEE Trans. Inform. Theory, vol. 43, pp. 1924-1934, Nov 1997.
Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities,"
IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998.
This linear programming approach was used in the ITIP software developed by
Raymond W. Yeung and Ying-On Yan ( http://user-www.ie.cuhk.edu.hk/~ITIP/ ).
Usage:
This library requires Python 3. Python 2 is not supported.
This library requires either PuLP, Pyomo or scipy for sparse linear programming.
Open an interactive Python console in the directory containing psitip.py.
from psitip import *
X, Y, Z = rv("X", "Y", "Z")
# Check for always true statements
# Use H(X + Y | Z + W) for the entropy H(X,Y|Z,W)
# Use I(X & Y + Z | W) for the mutual information I(X;Y,Z|W)
# As in ITIP, a False return value does not mean the inequality is
# false. It only means that it cannot be deduced by Shannon-type inequalities.
bool(I(X & Y) - H(X + Z) <= 0) # return True
bool(I(X & Y) == H(X) - H(X | Y)) # return True
# Each constraint (e.g. I(X & Y) - H(X + Z) <= 0 above) generates a Region object
# that can be intersected with each other (using the "&" operator). Casting a
# Region to bool returns whether the constraints in the Region are always satisfied.
# Use A >> B or A <= B to check if the conditions in A imply those in B
# The "<=" operator checks whether the region on the left is a subset of the
# region on the right.
# Note that "<=" does NOT denote reverse implication (B implies A); A <= B means A implies B.
bool(((I(X & Y) == 0) & (I(X+Y & Z) == 0)) >> (I(X & Z) == 0)) # return True
See test.py for more usage examples.
WARNING: Nested implication may produce incorrect results for some cases.
Use at your own risk. It is advisable to check whether the output auxiliary
random variables are indeed valid.
"""
import itertools
import collections
import array
import fractions
import warnings
import types
import functools
import math
import time
import logging
import random
import contextlib
import io
import heapq
try:
import os
import os.path
except ImportError:
os = None
try:
import numpy
import numpy.linalg
except ImportError:
numpy = None
try:
import scipy
import scipy.sparse
import scipy.optimize
import scipy.spatial
import scipy.special
import scipy.linalg
except ImportError:
scipy = None
try:
# Suppress stdout of pulp during import
# with contextlib.redirect_stdout(open(os.devnull, "w")):
with contextlib.redirect_stdout(io.StringIO()):
import pulp
except ImportError:
pulp = None
except Exception as err:
warnings.warn(str(err), RuntimeWarning)
pulp = None
try:
import pyomo.environ as pyo
from pyomo.opt import SolverFactory
logging.getLogger("pyomo.core").setLevel(logging.ERROR)
except ImportError:
pyo = None
try:
import ortools
import ortools.linear_solver
import ortools.linear_solver.pywraplp
except ImportError:
ortools = None
cddver = None
try:
import cdd
try:
cdd.Matrix([[1]], number_type="float")
cddver = "2"
except:
cddver = "3"
except ImportError:
cdd = None
def cdd_matrix_from_array(a, lin_set=None, number_type="float", rep_type=None):
    """Build a cdd matrix from a 2D array, supporting both pycddlib APIs.

    Returns None when cdd is not installed. rep_type defaults to
    cdd.RepType.GENERATOR and lin_set to an empty set. The newer
    ``cdd.matrix_from_array`` is tried first; if it does not exist, the older
    ``cdd.Matrix`` constructor is used (there number_type is honoured and
    rep_type / lin_set are assigned as attributes).
    """
    if cdd is None:
        return None
    rep_type = cdd.RepType.GENERATOR if rep_type is None else rep_type
    lin_set = set() if lin_set is None else lin_set
    try:
        return cdd.matrix_from_array(a, lin_set = lin_set, rep_type=rep_type)
    except AttributeError:
        # Older pycddlib: no module-level matrix_from_array.
        legacy = cdd.Matrix(a, number_type=number_type)
        legacy.rep_type = rep_type
        legacy.lin_set = lin_set
        return legacy
def cdd_convert(mat):
    """Convert a cdd matrix to the other representation via a cdd polyhedron.

    Returns a tuple (rows, lin_set) where rows is a list of the output
    matrix rows and lin_set is a set of linearity row indices, or None if cdd
    is not installed. The newer pycddlib API (polyhedron_from_matrix /
    copy_output) is tried first, falling back to cdd.Polyhedron and
    get_inequalities().
    """
    if cdd is None:
        return None
    try:
        out = cdd.copy_output(cdd.polyhedron_from_matrix(mat))
        return (list(out.array), set(out.lin_set))
    except AttributeError:
        # Older pycddlib API.
        ineqs = cdd.Polyhedron(mat).get_inequalities()
        linearity = set(ineqs.lin_set)
        return (list(ineqs), linearity)
def cdd_col_size(mat):
    """Number of columns of a cdd matrix.

    Uses ``mat.array`` when available (0 for an empty array), otherwise
    falls back to the ``mat.col_size`` attribute.
    """
    try:
        rows = mat.array
    except AttributeError:
        return mat.col_size
    return len(rows[0]) if len(rows) != 0 else 0
def cdd_row_size(mat):
    """Number of rows of a cdd matrix (``len(mat.array)``, else ``mat.row_size``)."""
    try:
        rows = mat.array
    except AttributeError:
        return mat.row_size
    return len(rows)
def cdd_mat_array(mat):
    """Return ``mat.array`` if the object has one, otherwise mat itself."""
    return getattr(mat, "array", mat)
def cdd_extend(mat, line):
    """Append the rows in line to the cdd matrix mat, in place.

    Uses cdd.matrix_append_to when it exists. Any AttributeError (including
    cdd being None, or an older pycddlib without that function) falls back to
    ``mat.extend(line)``.
    """
    try:
        # The function is looked up before the row matrix is built, so older
        # pycddlib versions bail out without constructing it.
        append_to = cdd.matrix_append_to
        append_to(mat, cdd_matrix_from_array(line))
    except AttributeError:
        mat.extend(line)
try:
import z3
except ImportError:
z3 = None
try:
import graphviz
except ImportError:
graphviz = None
try:
import matplotlib.pyplot as plt
import matplotlib.patches
import matplotlib.lines
import matplotlib.patheffects
import matplotlib.colors
from matplotlib.collections import PatchCollection
except ImportError:
matplotlib = None
plt = None
try:
import plotly
import plotly.offline
import plotly.graph_objects
except ImportError:
plotly = None
try:
import torch
import torch.optim
except ImportError:
torch = None
try:
import torch.linalg
except ImportError:
pass
try:
import IPython
import IPython.display
except ImportError:
IPython = None
try:
import lark
from lark import v_args as lark_v_args
from lark import Transformer as lark_Transformer
except ImportError:
lark = None
lark_v_args = lambda inline: (lambda x: x)
lark_Transformer = object
try:
import pylatexenc
from pylatexenc.latex2text import LatexNodes2Text
except ImportError:
pylatexenc = None
LatexNodes2Text = None
# try:
# import sympy
# import sympy.matrices.normalforms
# except ImportError:
# sympy = None
# try:
# import smithnormalform
# import smithnormalform.matrix
# import smithnormalform.snfproblem
# import smithnormalform.z
# except ImportError:
# smithnormalform = None
class LinearProgType:
    """Enumeration of linear programming formulations.

    NIL   -- no formulation
    H     -- variables are the joint entropies H(X_C)
    HC1BN -- variables of the form H(X | Y_C)
    HMIN  -- the default used by PsiOpts.settings["lptype"]
    """
    NIL, H, HC1BN, HMIN = range(4)
class PsiOpts:
    """ Options

    Global options are stored in the class attribute ``settings`` and changed
    with ``PsiOpts.setting(key=value, ...)``. ``PsiOpts(key=value, ...)`` builds
    a modified copy of the current settings which becomes active only inside
    a ``with`` block on that object.

    Attributes:
        solver : The linear programming solver used
            "scipy" : scipy.optimize.linprog
            "pulp.cbc" : PuLP with CBC
            "pulp.glpk" : PuLP with GLPK
            "pyomo.glpk" : Pyomo with GLPK
        str_style : Style of string conversion
            STR_STYLE_STANDARD : I(X,Y;Z|W)
            STR_STYLE_PSITIP : I(X+Y&Z|W)
            STR_STYLE_LATEX : I(X,Y;Z|W)
        lptype : Linear programming mode
            LinearProgType.H : Classical
            LinearProgType.HC1BN : Bayesian Network optimization
        verbose_auxsearch : Print auxiliary RV search progress
        verbose_lp : Print linear programming parameters
        rename_char : Char appended when renaming auxiliary RV
        eps : Epsilon for float comparison
    """

    STR_STYLE_STANDARD = 1
    STR_STYLE_PSITIP = 2
    STR_STYLE_LATEX = 4
    STR_STYLE_LATEX_ARRAY = 1 << 5
    STR_STYLE_LATEX_FRAC = 1 << 6
    STR_STYLE_LATEX_QUANTAFTER = 1 << 7
    STR_STYLE_MARKOV = 1 << 8
    STR_STYLE_REM_SYMBOLS = 1 << 9
    STR_STYLE_GRAPHVIZ = 1 << 10

    SFRL_LEVEL_SINGLE = 1
    SFRL_LEVEL_MULTIPLE = 2

    global_index = None

    stats = {
        "numlp": 0
    }

    settings = {
        "ent_coeff": 1.0 / math.log(2.0),
        "random": None,
        "quantum": False,
        "hge0": True,
        "hcge0": True,
        "ige0": True,
        "icge0": True,
        "figsize": None,
        "eps": 1e-10,
        "eps_lp": 1e-5,
        "eps_check": 1e-6,
        "eps_violate_cutoff": 1e-1,
        "max_denom": 1000000,
        "max_denom_mul": 10000,
        "max_denom_lp": 10000,
        "max_denom_try": 12,
        "condition_included": True,
        "avoid_enabled": True,
        "str_style": 2 + (1 << 8),
        "str_style_std": 1 + (1 << 8),
        "str_style_latex": 4 + (1 << 8),
        "str_style_repr": 2 + (1 << 8),
        "str_tosort": False,
        "str_lhsreal": True,
        "str_eqn_prefer_ge": False,
        "str_eps_hide": True,
        "str_proof_note": False,
        "rename_char": "_",
        "fcn_suffix": "@@32768@@#fcn",
        "lhsvar": "real",
        "meta_subs_criteria": False,
        "truth": None,
        "proof": None,
        "use_global_index": True,
        "indreg_enabled": True,
        "maxent_lex_enabled": True,
        "example_opt_num_points_mul": 1,
        "timer_start": None,
        "timer_end": None,
        "stop_file": None,
        "stop_file_last_check": None,
        "stop_file_exists_cache": False,
        "solver": None,
        "lptype": LinearProgType.HMIN,
        "lptype_H_if_proof": False,
        "lp_zero_group": True,
        "lp_no_zero_group_if_proof": True,
        "lp_bounded": False,
        "lp_ubound": 1e4,
        "lp_eps": 1e-3,
        "lp_eps_obj": 1e-4,
        "lp_zero_cutoff": -1e-5,
        "lp_cond_maximize": True,
        "lp_dual_form": False,
        "lp_dual_form_if_proof": True,
        "lp_dual_form_discover": False,
        "lp_bnet_reverse": False,
        "lp_bnet_hc": True,
        "lp_save_enabled": True,
        "fcn_mode": 1,
        "solver_scipy_maxsize": -1,
        "pulp_options": None,
        "pyomo_options": {},
        "verbose_solver": False,
        "simplify_enabled": True,
        "simplify_quick": False,
        "simplify_reduce_coeff": True,
        "simplify_balance": False,
        "simplify_remove_missing_aux": True,
        "simplify_aux": True,
        "simplify_aux_combine": True,
        "simplify_aux_commonpart": True,
        "simplify_aux_xor_len": 1,
        "simplify_aux_empty": False,
        "simplify_aux_recombine": False,
        "simplify_pair": True,
        "simplify_redundant": True,
        "simplify_redundant_full": True,
        "simplify_bayesnet": True,
        "simplify_expr_exhaust": True,
        "simplify_redundant_op": True,
        "simplify_union": False,
        "simplify_aux_relax": True,
        "simplify_sort": True,
        "simplify_regterm": True,
        "simplify_aux_hull": False,
        "simplify_aux_hull_lower_complexity": True,
        "simplify_aux_eq": False,
        "simplify_aux_strengthen": False,
        "simplify_num_iter": 1,
        "term_allow": None,
        "istorch": (torch is not None),
        "opt_optimizer": "SLSQP", #"sgd",
        "opt_eps_denom": 1e-8,
        "opt_eps_tol": 1e-7,
        "opt_learnrate": 0.04,
        "opt_learnrate2": 0.02,
        "opt_momentum": 0.0,
        "opt_num_iter": 800,
        "opt_num_iter2": 0,
        "opt_num_points": 1,
        "opt_eps_converge": 1e-9,
        "opt_num_hop": 1,
        "opt_hop_temp": 0.05,
        "opt_hop_prob": 0.2,
        "opt_alm_rho": 1.0,
        "opt_alm_rho_pow": 1.0,
        "opt_alm_step": 5,
        "opt_alm_penalty": 20.0,
        "opt_aux_card": 5,
        "opt_example_card": (2, 4),
        "imp_noncircular": True,
        "imp_noncircular_allaux": False,
        "imp_simplify": False,
        "prefer_expand": True,
        "tensorize_simplify": False,
        "eliminate_rays": False,
        "ignore_must": False,
        "forall_multiuse": True,
        "forall_multiuse_numsave": 128,
        "auxsearch_local": True,
        "auxsearch_leaveone": False,
        "auxsearch_leaveone_add_ineq": True,
        "auxsearch_strengthen": False,
        "auxsearch_aux_strengthen": True,
        "init_leaveone": True,
        "auxsearch_max_iter": 0,
        "auxsearch_op_casesteplimit": 16,
        "auxsearch_op_caselimit": 512,
        "auxsearch_sandwich": True,
        "auxsearch_sandwich_inc": False,
        "auxsearch_max_numupdate_one_if_proof": True,
        "flatten_minmax_elim": False,
        "flatten_distribute": True,
        "flatten_distribute_multi": False,
        "presolve_aux_hull": False,
        "presolve_aux_hull_quick": True,
        "presolve_aux_eq": False,
        "presolve_aux_eq_quick": True,
        "presolve_simplify": False,
        "alg_group_nit": 10,
        "alg_normalize": True,
        "bayesnet_semigraphoid_iter": 1000,
        "proof_enabled": False,
        "proof_nowrite": False,
        "proof_step_dualsum": True,
        "proof_step_dualsum_short": True,
        "proof_step_dualsum_exhaust": True,
        "proof_step_dualsum_prog": False,
        "proof_step_chain": True,
        "proof_step_optimize": True,
        "proof_step_simplify": False,
        "proof_step_term_simplify": False,
        "proof_step_compress": True,
        "proof_step_term_nshuffle": 1,
        "proof_step_bayesnet": True,
        "proof_deficit_separate": True,
        "proof_repeat_implicant": False,
        "proof_yield_one": False,
        "proof_note": True,
        "proof_note_color": None,
        "proof_note_newline": 80,
        "proof_note_skip_trivial": True,
        "proof_step_expand_def": False,
        "proof_noskip": False,
        "repr_simplify": True,
        "repr_check": False,
        "repr_latex": False,
        "venn_latex": "mathregular",
        "venn_style": "",
        "plot_lib": None,
        "graph_group_color": "transparent",
        "graph_group_fillcolor": "grey93",
        "codingmodel_node_shape": "rect",
        "codingmodel_node_fillcolor": "grey93",
        "codingmodel_channel_shape": "rect",
        "codingmodel_channel_fillcolor": "white",
        "discover_hull_facet_enumerate": False,
        "discover_hull_facet_enumerate_numyield": 300,
        "discover_hull_frac_enabled": True,
        "discover_hull_frac_denom": 1000,
        "discover_max_facet": None, # 1000000,
        "discover_num_simplex": 0,
        "str_ineq_sep": ",",
        "str_brace_l": "{",
        "str_brace_r": "}",
        "str_eq": "==",
        "str_float": False,
        "str_float_dp": 5,
        "solve_display_reg": None,
        "use_pylatexenc": True,
        "graphviz_text_convert": True,
        "latex_H": "H",
        "latex_I": "I",
        "latex_rv_delim": ",",
        "latex_cond": "|",
        "latex_sup": "\\sup",
        "latex_inf": "\\inf",
        "latex_max": "\\max",
        "latex_min": "\\min",
        "latex_exists": "\\exists",
        "latex_forall": "\\forall",
        "latex_quantifier_sep": ":\\,",
        "latex_indep": "{\\perp\\!\\!\\perp}",
        "latex_markov": "\\leftrightarrow",
        "latex_mi_delim": ";",
        "latex_list_bracket_l": "[",
        "latex_list_bracket_r": "]",
        "latex_matimplies": "\\Rightarrow",
        "latex_equiv": "\\Leftrightarrow",
        "latex_implies": "\\Rightarrow",
        "latex_times": "\\cdot",
        "latex_prob": "\\mathbf{P}",
        "latex_rv_empty": "\\emptyset",
        "latex_region_universe": "\\top",
        "latex_region_empty": "\\bot",
        "latex_tautology": "\\top",
        "latex_contradiction": "\\bot",
        "latex_unknown": "?",
        "latex_or": "\\vee",
        "latex_and": "\\wedge",
        "latex_infty": "\\infty",
        "latex_eps": "\\epsilon",
        "latex_because": "\\because",
        "latex_therefore": "\\therefore",
        "latex_is_typical": "\\in\\mathcal{T}",
        "latex_subs": ":=",
        "latex_subs_bracket_l": "\\{",
        "latex_subs_bracket_r": ".",
        "latex_group_mul": "", # "\\cdot",
        "latex_group_id": "\\mathrm{id}",
        "latex_group_add": "+",
        "latex_group_minus": "-",
        "latex_group_additive": False,
        "latex_color": None,
        "latex_line_len": None,
        "verbose_lp": False,
        "verbose_lp_cons": False,
        "verbose_auxsearch": False,
        "verbose_auxsearch_step": False,
        "verbose_auxsearch_result": False,
        "verbose_auxsearch_cache": False,
        "verbose_auxsearch_step_cached": False,
        "verbose_auxsearch_op": False,
        "verbose_auxsearch_op_step": False,
        "verbose_auxsearch_op_detail": False,
        "verbose_auxsearch_op_detail2": False,
        "verbose_subset": False,
        "verbose_sfrl": False,
        "verbose_flatten": False,
        "verbose_eliminate": False,
        "verbose_eliminate_toreal": False,
        "verbose_semigraphoid": False,
        "verbose_proof": False,
        "verbose_discover": False,
        "verbose_discover_ineq": False,
        "verbose_discover_detail": False,
        "verbose_discover_outer": False,
        "verbose_discover_terms": False,
        "verbose_discover_terms_inner": False,
        "verbose_discover_terms_outer": False,
        "verbose_opt": False,
        "verbose_opt_step": False,
        "verbose_opt_step_var": False,
        "verbose_float_dp": 8,
        "verbose_commmodel": False,
        "verbose_codingmodel": False,
        "verbose_proof_step": False,
        "verbose_aux_reduced": False,
        "verbose_partialorder": False,
        "sfrl_level": 0,
        "sfrl_maxsize": 1,
        "sfrl_gap": ""
    }

    @staticmethod
    def _set_term_allow_flag(d, flag, value):
        """Set (value truthy) or clear (value falsy) one TermAllowType bit.

        A d["term_allow"] of None means every term type is allowed (the
        "all" option), so setting a bit leaves it as None. OR-ing into None
        used to raise TypeError. Clearing a bit first replaces None with
        TermAllowType.DEFAULT.
        """
        if value:
            if d["term_allow"] is not None:
                d["term_allow"] |= flag
        else:
            if d["term_allow"] is None:
                d["term_allow"] = TermAllowType.DEFAULT
            d["term_allow"] &= ~flag

    @staticmethod
    def set_setting_dict(d, key, value):
        """Apply one option to the settings dictionary d, in place.

        Besides the plain keys of ``settings``, composite keys such as
        "level", "simplify_level", "auxsearch_level", "timelimit", "proof_new"
        or "term_allow_hc" are understood; they update several entries of d.
        A key ending in "_all" applies value to every existing key of d that
        starts with the part before "_all" (e.g. "verbose_auxsearch_all").

        Raises
        ------
        KeyError
            If key is neither a composite key nor an existing key of d.
        """
        if key.endswith("_all"):
            keyprefix = key[:-len("_all")]
            for key2 in d:
                if key2.startswith(keyprefix):
                    PsiOpts.set_setting_dict(d, key2, value)
            return

        if key == "sfrl":
            if value == "no":
                d["sfrl_level"] = 0
            else:
                d["sfrl_level"] = max(d["sfrl_level"], PsiOpts.SFRL_LEVEL_SINGLE)
                if value == "frl":
                    d["sfrl_gap"] = ""
                elif value.startswith("sfrl_gap."):
                    d["sfrl_gap"] = value[value.index(".") + 1 :]
                elif value == "sfrl_nogap":
                    d["sfrl_gap"] = "zero"

        elif key == "str_style":
            d["str_style"] = iutil.convert_str_style(value)
            if isinstance(value, str) and value.lower() == "aitip":
                d["str_ineq_sep"] = ""
                d["str_brace_l"] = " "
                d["str_brace_r"] = ""
                d["str_eq"] = "="
                d["str_float"] = True

        elif key == "timer" or key == "timelimit":
            if value is None:
                d["timer_start"] = None
                d["timer_end"] = None
            else:
                if isinstance(value, str):
                    value = iutil.get_duration_ms(value)
                # Timer values are in milliseconds.
                curtime = time.time() * 1000
                d["timer_start"] = curtime
                d["timer_end"] = curtime + float(value)

        elif key == "stop_file":
            if value != d["stop_file"]:
                d["stop_file"] = value
                d["stop_file_last_check"] = None

        elif key == "simplify_level":
            d["simplify_enabled"] = value >= 1
            d["simplify_quick"] = value <= 2
            d["simplify_aux"] = value >= 4
            d["simplify_aux_combine"] = value >= 4
            d["simplify_aux_commonpart"] = value >= 5
            if value >= 9:
                d["simplify_aux_xor_len"] = 4
            elif value >= 7:
                d["simplify_aux_xor_len"] = 3
            else:
                d["simplify_aux_xor_len"] = 1
            d["simplify_aux_empty"] = value >= 6
            d["simplify_aux_recombine"] = value >= 10
            d["simplify_pair"] = value >= 2
            d["simplify_redundant"] = value >= 3
            d["simplify_redundant_full"] = value >= 4
            d["simplify_bayesnet"] = value >= 4
            d["simplify_redundant_op"] = value >= 4
            d["simplify_union"] = value >= 8
            d["simplify_aux_hull_lower_complexity"] = value >= 5
            d["simplify_num_iter"] = 2 if value >= 9 else 1

        elif key == "simplify_strengthen":
            d["simplify_aux_strengthen"] = value
            d["simplify_aux_eq"] = value

        elif key == "simplify_relax":
            d["simplify_aux_relax"] = value
            d["simplify_aux_hull"] = value

        elif key == "auxsearch_level":
            d["auxsearch_leaveone"] = value >= 6
            d["auxsearch_strengthen"] = value >= 9
            d["presolve_aux_hull"] = value >= 7
            d["presolve_aux_hull_quick"] = value >= 4
            d["presolve_aux_eq"] = value >= 8
            d["presolve_aux_eq_quick"] = value >= 5
            d["presolve_simplify"] = value >= 10

        elif key == "level":
            PsiOpts.set_setting_dict(d, "simplify_level", value)
            PsiOpts.set_setting_dict(d, "auxsearch_level", value)

        elif key == "ent_base":
            d["ent_coeff"] = 1.0 / math.log(value)

        elif key == "term_allow":
            if isinstance(value, str):
                value = value.lower()
            if value == "h":
                d["term_allow"] = TermAllowType.H
            elif value == "i":
                d["term_allow"] = TermAllowType.H | TermAllowType.I
            elif value == "hc" or value == "ch":
                d["term_allow"] = TermAllowType.H | TermAllowType.HC
            elif value == "ic" or value == "ci":
                d["term_allow"] = TermAllowType.H | TermAllowType.HC | TermAllowType.I | TermAllowType.IC
            elif value == "all":
                d["term_allow"] = None
            else:
                d["term_allow"] = value

        elif key == "term_allow_hc":
            PsiOpts._set_term_allow_flag(d, TermAllowType.HC, value)
        elif key == "term_allow_i":
            PsiOpts._set_term_allow_flag(d, TermAllowType.I, value)
        elif key == "term_allow_ic":
            PsiOpts._set_term_allow_flag(d, TermAllowType.IC, value)
        elif key == "term_allow_i3":
            PsiOpts._set_term_allow_flag(d, TermAllowType.I3, value)

        # NOTE: "verbose_auxsearch_all" is handled by the generic "_all" rule
        # at the top of this method; a dedicated branch here was unreachable.

        elif key == "verbose_proof":
            d["verbose_proof"] = value
            if value:
                d["proof_enabled"] = value
                if d["proof"] is None:
                    d["proof"] = ProofObj.empty()

        elif key == "proof_enabled":
            d["proof_enabled"] = value
            if value:
                if d["proof"] is None:
                    d["proof"] = ProofObj.empty()

        elif key == "pulp_solver":
            d["solver"] = "pulp.other"
            iutil.pulp_solver = value

        elif key == "lptype":
            if value == "H":
                d[key] = LinearProgType.H
            elif value == "HMIN":
                d[key] = LinearProgType.HMIN
            elif isinstance(value, str):
                d[key] = LinearProgType.HC1BN
            else:
                d[key] = value

        elif key == "truth":
            if value is None:
                d["truth"] = None
            else:
                d["truth"] = value.copy()

        elif key == "truth_add":
            if value is not None:
                if d["truth"] is None:
                    d["truth"] = value.copy()
                else:
                    d["truth"] = d["truth"] & value

        elif key == "cases":
            d["auxsearch_leaveone"] = value

        elif key == "proof_add":
            if d["proof"] is None:
                d["proof"] = ProofObj.empty()
            # Consult the dictionary being edited (not the global settings),
            # so PsiOpts(verbose_proof=True, proof_add=...) behaves as expected.
            if d.get("verbose_proof", False):
                print(value.tostring(prev = d["proof"]))
                print("")
            d["proof"] += value

        elif key == "proof_clear":
            if value:
                if d["proof"] is not None:
                    d["proof"].clear()

        elif key == "proof_new":
            if value:
                d["proof_enabled"] = value
                d["proof"] = ProofObj.empty()
                if isinstance(value, str):
                    PsiOpts.set_setting_dict(d, "proof_option", value)

        elif key == "proof_shorten":
            d["lp_dual_form_if_proof"] = value
            d["lp_no_zero_group_if_proof"] = value

        elif key == "proof_detail":
            d["lptype_H_if_proof"] = value
            d["proof_noskip"] = value
            d["proof_step_compress"] = not value

        elif key == "proof_option":
            # Comma-separated list of "proof_<name>" options to enable.
            for c in value.split(","):
                c = c.strip().lower()
                PsiOpts.set_setting_dict(d, "proof_" + c, True)

        elif key == "proof_branch":
            if value:
                d["proof_enabled"] = value
                if d["proof"] is None:
                    d["proof"] = ProofObj.empty()
                else:
                    d["proof"] = d["proof"].copy()
                if isinstance(value, str):
                    PsiOpts.set_setting_dict(d, "proof_option", value)

        elif key == "proof_step_in":
            if d["proof"] is None:
                d["proof"] = ProofObj.empty()
            d["proof"] = d["proof"].step_in(value)

        elif key == "proof_step_out":
            if d["proof"] is None:
                d["proof"] = ProofObj.empty()
            d["proof"] = d["proof"].step_out()

        elif key == "note":
            d["str_proof_note"] = value

        elif key == "stats_reset":
            PsiOpts.stats["numlp"] = 0

        elif key == "random_seed":
            rnd = numpy.random.default_rng(value)
            d["random"] = rnd

        elif key == "opt_singlepass":
            d["opt_learnrate"] = 0.04
            d["opt_num_iter"] = 800
            d["opt_num_iter2"] = 0
            d["opt_num_points"] = 1
            d["opt_num_hop"] = 1

        elif key == "opt_basinhopping":
            if value:
                d["opt_learnrate"] = 0.12
                d["opt_learnrate2"] = 0.02
                d["opt_num_points"] = 5
                d["opt_num_iter"] = 15
                d["opt_num_iter2"] = 500
                d["opt_num_hop"] = 20
            else:
                PsiOpts.set_setting_dict(d, "opt_singlepass", True)

        elif key == "opt_learnrate_mul":
            d["opt_learnrate"] *= value
            d["opt_learnrate2"] *= value

        elif key == "opt_num_iter_mul":
            d["opt_num_iter"] = int(d["opt_num_iter"] * value)
            d["opt_num_iter2"] = int(d["opt_num_iter2"] * value)

        elif key == "opt_num_points_mul":
            d["opt_num_points"] = int(d["opt_num_points"] * value)

        else:
            if key not in d:
                raise KeyError("Option '" + str(key) + "' not found.")
            d[key] = value

    @staticmethod
    def apply_dict(d):
        """Propagate settings in d that need side effects elsewhere (repr_latex)."""
        IBaseObj.set_repr_latex(d["repr_latex"])

    @staticmethod
    def set_setting(**kwargs):
        """Apply each keyword option to the global settings, then apply_dict."""
        for key, value in kwargs.items():
            PsiOpts.set_setting_dict(PsiOpts.settings, key, value)
        PsiOpts.apply_dict(PsiOpts.settings)

    @staticmethod
    def setting(**kwargs):
        """Alias of set_setting."""
        PsiOpts.set_setting(**kwargs)

    @staticmethod
    def get_setting(key, defaultval = None):
        """Return the global setting key, or defaultval if it does not exist."""
        return PsiOpts.settings.get(key, defaultval)

    @staticmethod
    def get_proof():
        """Return the current proof object (may be None)."""
        return PsiOpts.settings["proof"]

    @staticmethod
    def set_proof(value):
        """Install value as the proof, copying into an existing proof object if any."""
        if PsiOpts.settings["proof"] is None:
            PsiOpts.settings["proof"] = value
        else:
            PsiOpts.settings["proof"].copy_(value)

    @staticmethod
    def get_plotlib(lib = None):
        """Pick a plotting library name ("plotly" / "matplotlib") or None.

        lib defaults to settings["plot_lib"]; None means "any available",
        with plotly preferred.
        """
        if lib is None:
            lib = PsiOpts.settings["plot_lib"]
        if plotly is not None and (lib is None or lib == "plotly"):
            return "plotly"
        elif plt is not None and (lib is None or lib == "matplotlib"):
            return "matplotlib"
        return None

    @staticmethod
    def get_random():
        """Return the shared numpy random Generator, creating it on first use."""
        rnd = PsiOpts.settings["random"]
        if rnd is None:
            rnd = numpy.random.default_rng()
            PsiOpts.settings["random"] = rnd
        return rnd

    @staticmethod
    def get_truth():
        """Return the currently assumed true region (may be None)."""
        return PsiOpts.settings["truth"]

    @staticmethod
    def timer_left():
        """Milliseconds left before timer_end, or None if no timer is set."""
        if PsiOpts.settings["timer_end"] is None:
            return None
        curtime = time.time() * 1000
        return PsiOpts.settings["timer_end"] - curtime

    @staticmethod
    def timer_left_sec():
        """Seconds left (rounded to int) before timer_end, or None."""
        r = PsiOpts.timer_left()
        if r is None:
            return None
        return int(round(r / 1000.0))

    @staticmethod
    def is_timer_ended_time():
        """Whether the time limit (if any) has passed."""
        if PsiOpts.settings["timer_end"] is None:
            return False
        curtime = time.time() * 1000
        return curtime > PsiOpts.settings["timer_end"]

    @staticmethod
    def is_timer_ended_file():
        """Whether the stop file (if any) exists.

        The filesystem is checked at most once every 5 seconds; in between,
        the cached result is returned.
        """
        if PsiOpts.settings["stop_file"] is None:
            return False
        curtime = time.time() * 1000
        if PsiOpts.settings["stop_file_last_check"] is not None and curtime <= PsiOpts.settings["stop_file_last_check"] + 5 * 1000:
            return PsiOpts.settings["stop_file_exists_cache"]
        PsiOpts.settings["stop_file_last_check"] = curtime
        PsiOpts.settings["stop_file_exists_cache"] = os.path.exists(PsiOpts.settings["stop_file"])
        return PsiOpts.settings["stop_file_exists_cache"]

    @staticmethod
    def is_timer_ended():
        """Whether either the time limit or the stop-file condition has triggered."""
        return PsiOpts.is_timer_ended_time() or PsiOpts.is_timer_ended_file()

    @staticmethod
    def has_timer():
        """Whether a time limit or a stop file is configured."""
        return PsiOpts.settings["timer_end"] is not None or PsiOpts.settings["stop_file"] is not None

    @staticmethod
    def get_pyomo_options():
        """Return a copy of pyomo_options, adding the solver-specific time limit key."""
        r = dict(PsiOpts.settings["pyomo_options"])
        timelimit = PsiOpts.timer_left_sec()
        if timelimit is not None:
            csolver = iutil.get_solver()
            if csolver == "pyomo.glpk":
                r["tmlim"] = timelimit
            elif csolver == "pyomo.cplex":
                r["timelimit"] = timelimit
            elif csolver == "pyomo.gurobi":
                r["TimeLimit"] = timelimit
            elif csolver == "pyomo.cbc":
                r["seconds"] = timelimit
        return r

    @staticmethod
    def setting_strengthen_sign(s):
        """Classify a simplify option: 1 if it strengthens, -1 if it relaxes, else 0."""
        if s.startswith("simplify_"):
            s = s[len("simplify_"):]
        if s == "aux_eq" or s == "aux_strengthen" or s == "strengthen":
            return 1
        elif s == "aux_relax" or s == "aux_hull" or s == "relax":
            return -1
        return 0

    @staticmethod
    def setting_strengthen_split(d):
        """Split option dict d into (non-relaxing, non-strengthening) dicts.

        Options whose value is not exactly True go into both results.
        """
        r0 = dict()
        r1 = dict()
        for a, b in d.items():
            if b is True:
                t = PsiOpts.setting_strengthen_sign(a)
                if t >= 0:
                    r0[a] = b
                if t <= 0:
                    r1[a] = b
            else:
                r0[a] = b
                r1[a] = b
        return (r0, r1)

    def __init__(self, **kwargs):
        """Create a scoped set of options.

        The current global settings are (shallow-)copied and each keyword
        argument is applied to the copy with set_setting_dict. The copy only
        becomes the active settings inside a ``with`` block on this object.

        Parameters
        ----------
        **kwargs
            Option names and values, as accepted by PsiOpts.setting.
        """
        self.cur_settings = PsiOpts.settings.copy()
        for key, value in kwargs.items():
            PsiOpts.set_setting_dict(self.cur_settings, key, value)

    def __enter__(self):
        """Swap the scoped settings in; returns the now-active settings dict."""
        PsiOpts.settings, self.cur_settings = self.cur_settings, PsiOpts.settings
        PsiOpts.apply_dict(PsiOpts.settings)
        return PsiOpts.settings

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Swap the previous settings back in."""
        PsiOpts.settings, self.cur_settings = self.cur_settings, PsiOpts.settings
        PsiOpts.apply_dict(PsiOpts.settings)
class iutil:
"""Common utilities
"""
solver_list = ["ortools.GLOP", "pulp.glpk", "pyomo.glpk", "pulp.cbc", "scipy", "z3"]
pulp_solver = None
pulp_solvers = {}
cur_count = 0
cur_count_name = {}
@staticmethod
def display_latex(s, ismath = True, metadata = None):
    """Render the LaTeX string s through IPython.display.

    If settings["latex_color"] is set, s is wrapped in \\color{...}{...}.
    ismath selects IPython.display.Math (True) or IPython.display.Latex.
    Nothing is returned.
    """
    latex_color = PsiOpts.settings["latex_color"]
    body = s if latex_color is None else "\\color{" + latex_color + "}{" + s + "}"
    renderer = IPython.display.Math if ismath else IPython.display.Latex
    IPython.display.display(renderer(body, metadata = metadata))
@staticmethod
def float_tostr(x, style = 0, bracket = True, force_float = False):
    """Format a number for display.

    Infinite/NaN values give "\\infty", "-\\infty" or "?" in LaTeX style,
    str(x) otherwise. Values within 1e-10 of 0 or of an integer print as
    that integer. Other values print as a fraction (denominator at most
    settings["max_denom"]; "\\frac{..}{..}" with STR_STYLE_LATEX_FRAC, else
    "(p/q)" or "p/q" depending on bracket), or as a fixed-point decimal with
    settings["str_float_dp"] digits.

    force_float: True, an int >= 10, or settings["str_float"] forces the
    decimal form. An int >= 5 tries a fraction with denominator at most
    settings["max_denom_try"] and falls back to decimal if that is not
    within 1e-10.
    """
    if x == numpy.inf or x == -numpy.inf or numpy.isnan(x):
        if style & PsiOpts.STR_STYLE_LATEX:
            if x == numpy.inf:
                return r"\infty"
            elif x == -numpy.inf:
                return r"-\infty"
            else:
                return "?"
        else:
            return str(x)

    ceps = 1e-10
    if abs(x) <= ceps:
        return "0"
    if abs(x - round(x)) <= ceps:
        return str(int(round(x)))

    to_force_float = False
    to_try_float = False
    if force_float is True or (isinstance(force_float, int) and force_float >= 10) or PsiOpts.settings["str_float"]:
        to_force_float = True
    elif isinstance(force_float, int) and force_float >= 5:
        to_try_float = True

    float_fmt = "{:." + str(PsiOpts.settings["str_float_dp"]) + "f}"
    if to_force_float:
        return float_fmt.format(x)

    denom = PsiOpts.settings["max_denom_try" if to_try_float else "max_denom"]
    magnitude = abs(x)
    frac = fractions.Fraction(magnitude).limit_denominator(denom)
    # frac approximates |x|; comparing it with the signed x (as before) made
    # every negative value fail this check and fall back to decimal.
    if to_try_float and abs(magnitude - float(frac)) > ceps:
        return float_fmt.format(x)

    sign = "" if x > 0 else "-"
    if style & PsiOpts.STR_STYLE_LATEX_FRAC:
        return sign + "\\frac{" + str(frac.numerator) + "}{" + str(frac.denominator) + "}"
    if bracket:
        return sign + "(" + str(frac) + ")"
    return sign + str(frac)
@staticmethod
def float_snap(x, denom = None, force = False, eps = None):
if denom is None:
denom = PsiOpts.settings["max_denom"]
if eps is None:
eps = PsiOpts.settings["eps"]
t = float(fractions.Fraction(x).limit_denominator(denom))
if force or abs(x - t) <= eps:
return t
return x
@staticmethod
def isconstzero(x, eps = None):
    """Whether x is a constant equal to zero.

    An Expr is first reduced with get_const(). bool: True iff False.
    int/Fraction: exact comparison with 0. float: |x| <= eps (default
    settings["eps"]). Anything else: False.
    """
    tol = PsiOpts.settings["eps"] if eps is None else eps
    value = x.get_const() if isinstance(x, Expr) else x
    if isinstance(value, bool):
        return not value
    if isinstance(value, (int, fractions.Fraction)):
        return value == 0
    if isinstance(value, float):
        return abs(value) <= tol
    return False
@staticmethod
def tostr_verbose(x):
if isinstance(x, float):
dp = PsiOpts.settings["verbose_float_dp"]
if dp is None:
return str(x)
return ("{:." + str(dp) + "f}").format(x)
elif isinstance(x, list):
return "[" + ", ".join(iutil.tostr_verbose(a) for a in x) + "]"
elif isinstance(x, tuple):
return "(" + ", ".join(iutil.tostr_verbose(a) for a in x) + ")"
else:
return str(x)
@staticmethod
def float_toz3(x):
    """Convert a number to an exact z3 real constant (None if z3 is missing).

    Values within 1e-10 of an integer become z3.RealVal(int); others become
    z3.Q(p, q) with q at most settings["max_denom"]. The sign is preserved
    (previously negative non-integers were turned positive).
    """
    if z3 is None:
        return None
    if abs(x - round(x)) <= 1e-10:
        return z3.RealVal(int(round(x)))
    frac = fractions.Fraction(abs(x)).limit_denominator(
        PsiOpts.settings["max_denom"])
    if x < 0:
        frac = -frac
    return z3.Q(frac.numerator, frac.denominator)
@staticmethod
def latex_to_text(s, style=""):
    """Convert a LaTeX string to displayable text.

    With style="graphviz" (settings["graphviz_text_convert"] on and
    pylatexenc available), ^{..} and _{..} become <sup>..</sup> and
    <sub>..</sub>. If any such markup was produced, the result is wrapped in
    <...> so Graphviz treats it as an HTML-like label. Otherwise, if
    settings["use_pylatexenc"] is on and pylatexenc is available, plain-text
    conversion is used. If neither applies, s is returned unchanged.
    """
    if style == "graphviz" and PsiOpts.settings["graphviz_text_convert"] and pylatexenc:
        # Reference: https://github.com/phfaist/pylatexenc/issues/36
        cwalker = pylatexenc.latexwalker.get_default_latex_context_db()
        cwalker.add_context_category("super_and_sub", specials=[
            pylatexenc.macrospec.SpecialsSpec("^", args_parser=pylatexenc.macrospec.MacroStandardArgsParser("{")),
            pylatexenc.macrospec.SpecialsSpec("_", args_parser=pylatexenc.macrospec.MacroStandardArgsParser("{")),
        ])
        ctotext = pylatexenc.latex2text.get_default_latex_context_db()
        ctotext.add_context_category("super_and_sub", specials=[
            pylatexenc.latex2text.SpecialsTextSpec("^", simplify_repl="<sup>%s</sup>"),
            pylatexenc.latex2text.SpecialsTextSpec("_", simplify_repl="<sub>%s</sub>"),
        ])
        s = pylatexenc.latex2text.LatexNodes2Text(latex_context=ctotext).nodelist_to_text(
            pylatexenc.latexwalker.LatexWalker(s, latex_context=cwalker).get_latex_nodes()[0])
        # Only switch to an HTML-like label when markup was emitted; the
        # previous test `"" in s` was always true.
        if "<sup>" in s or "<sub>" in s:
            s = "<" + s + ">"
        return s
    if PsiOpts.settings["use_pylatexenc"] and LatexNodes2Text:
        s = LatexNodes2Text().latex_to_text(s)
    return s
@staticmethod
def num_open_brackets(s):
return s.count("(") + s.count("[") + s.count("{") - (
s.count("}") + s.count("]") + s.count(")"))
@staticmethod
def split_comma_old(s, delim = ", "):
if not isinstance(s, str):
return [s]
t = s.split(delim)
c = ""
r = []
for a in t:
if c != "":
c += delim
c += a
if iutil.num_open_brackets(c) == 0:
r.append(c)
c = ""
if c != "":
r.append(c)
return r
@staticmethod
def split_comma(s):
if not isinstance(s, str):
return [s]
r = []
nopen = 0
for a in s:
if nopen == 0 and (a == "," or a == " "):
if len(r) == 0 or r[-1] != "":
r.append("")
else:
if a == "(" or a == "[" or a == "{":
nopen += 1
elif a == ")" or a == "]" or a == "}":
nopen -= 1
if len(r) == 0:
r.append("")
r[-1] += a
return r
@staticmethod
def get_count(counter_name = None, add = True):
if counter_name is None:
if add:
iutil.cur_count += 1
return iutil.cur_count
if counter_name not in iutil.cur_count_name:
iutil.cur_count_name[counter_name] = 0
if add:
iutil.cur_count_name[counter_name] += 1
return iutil.cur_count_name[counter_name]
@staticmethod
def gcd(a, b):
while b > 0:
a, b = b, a % b
return a
@staticmethod
def lcm(a, b):
    """Return the least common multiple of `a` and `b`.

    Returns 0 when either argument is 0. Previously lcm(0, 0) raised
    ZeroDivisionError, because gcd(0, 0) == 0.
    """
    if a == 0 or b == 0:
        return 0
    return (a // iutil.gcd(a, b)) * b
@staticmethod
def distance_frommat(m, n0, n1):
if n0 == 0 and n1 == 0:
return 0.0
vis0 = [False] * n0
vis1 = [False] * n1
r = 0.0
for it in range(min(n0, n1)):
cmin = None
cmin0 = -1
cmin1 = -1
for i0 in range(n0):
if vis0[i0]:
continue
for i1 in range(n1):
if vis1[i1]:
continue
if cmin is None or cmin > m[i0][i1]:
cmin = m[i0][i1]
cmin0 = i0
cmin1 = i1
r += cmin
vis0[cmin0] = True
vis1[cmin1] = True
return (r + (n0 + n1 - min(n0, n1) * 2)) / (n0 + n1 - min(n0, n1))
@staticmethod
def distance(a, b):
    """Distance between collections `a` and `b`.

    Builds a cost matrix from the element-wise `x.distance(y)` values and
    scores it with `iutil.distance_frommat`.
    """
    cost = [[x.distance(y) for y in b] for x in a]
    return iutil.distance_frommat(cost, len(a), len(b))
@staticmethod
def hasinstance(a, t):
if isinstance(a, t):
return True
if isinstance(a, (tuple, list)):
return any(iutil.hasinstance(x, t) for x in a)
@staticmethod
def convert_algtype(algtype):
    """Normalize an algebraic-structure specification.

    Parameters
    ----------
    algtype : None, str or an AlgType value
        None and "" map to 0. Strings map to AlgType constants:
        "semigroup"/"monoid" -> SEMIGROUP, "group" -> GROUP,
        "abelian" -> ABELIAN, "torsionfree"/"vector"/"real" -> REAL.
        Any non-string value is returned unchanged.

    Raises
    ------
    ValueError
        If `algtype` is an unrecognized string.
    """
    if algtype is None:
        return 0
    if not isinstance(algtype, str):
        return algtype
    if algtype == "":
        return 0
    if algtype in ("semigroup", "monoid"):
        return AlgType.SEMIGROUP
    if algtype == "group":
        return AlgType.GROUP
    if algtype == "abelian":
        return AlgType.ABELIAN
    if algtype in ("torsionfree", "vector", "real"):
        return AlgType.REAL
    raise ValueError("Algebraic structure \"" + algtype + "\" not supported. Options are "
        + ", ".join("\"" + x + "\"" for x in ("semigroup", "group", "abelian", "torsionfree", "vector", "real")))
@staticmethod
def convert_str_style(style):
    """Translate a style name into a STR_STYLE bit mask.

    Recognized names (case-insensitive): "standard"/"std", "aitip",
    "psitip"/"code", "latex", "latex_noarray" and "graphviz". Non-string
    values, such as an existing bit mask, are returned unchanged.

    Raises
    ------
    ValueError
        If `style` is a string that is not a recognized style name.
        Previously such names fell through without producing a usable mask.
    """
    if not isinstance(style, str):
        return style
    name = style.lower()
    if name in ("standard", "std"):
        return PsiOpts.settings["str_style_std"]
    if name == "aitip":
        return PsiOpts.settings["str_style_std"] & ~(PsiOpts.STR_STYLE_MARKOV) | PsiOpts.STR_STYLE_REM_SYMBOLS
    if name in ("psitip", "code"):
        return PsiOpts.settings["str_style_repr"]
    if name == "latex":
        return PsiOpts.settings["str_style_latex"] | PsiOpts.STR_STYLE_LATEX_ARRAY | PsiOpts.STR_STYLE_LATEX_FRAC
    if name == "latex_noarray":
        return PsiOpts.settings["str_style_latex"] | PsiOpts.STR_STYLE_LATEX_FRAC
    if name == "graphviz":
        if PsiOpts.settings["graphviz_text_convert"]:
            return iutil.convert_str_style("latex") | PsiOpts.STR_STYLE_GRAPHVIZ
        return iutil.convert_str_style("std")
    raise ValueError("Style \"" + style + "\" not supported. Options are "
        + ", ".join("\"" + x + "\"" for x in ("standard", "std", "aitip", "psitip", "code", "latex", "latex_noarray", "graphviz")))
@staticmethod
def reverse_eqnstr(eqnstr):
if eqnstr == "<=":
return ">="
elif eqnstr == ">=":
return "<="
elif eqnstr == "<":
return ">"
elif eqnstr == ">":
return "<"
else:
return eqnstr
@staticmethod
def eqnstr_style(eqnstr, style):
    """Render relation string `eqnstr` for output style `style`.

    In standard style, "==" becomes the "str_eq" setting. In LaTeX style,
    "<=", ">=", "==" and "!=" become their LaTeX symbols. Everything else
    is returned unchanged.
    """
    if eqnstr == "":
        return ""
    if style & PsiOpts.STR_STYLE_STANDARD:
        return PsiOpts.settings["str_eq"] if eqnstr == "==" else eqnstr
    if style & PsiOpts.STR_STYLE_LATEX:
        latex_ops = {"<=": "\\leq", ">=": "\\geq", "==": "=", "!=": "\\neq"}
        return latex_ops.get(eqnstr, eqnstr)
    return eqnstr
@staticmethod
def op_str(eqnstr, a, b, isle = None):
if eqnstr == "<=":
return a <= b
elif eqnstr == ">=":
return a >= b
elif eqnstr == "<":
return a < b
elif eqnstr == ">":
return a > b
elif eqnstr == "==" or eqnstr == "=":
if isle is not None:
if isle:
return a == b
else:
return b == a
return a == b
elif eqnstr == "!=":
return a != b
elif eqnstr == ">>":
return a >> b
elif eqnstr == "<<":
return a << b
elif eqnstr == "+":
return a + b
elif eqnstr == "-":
return a - b
elif eqnstr == "*":
return a * b
elif eqnstr == "/":
return a / b
elif eqnstr == "**":
return a ** b
elif eqnstr == "//":
return a // b
elif eqnstr == "^":
return a ^ b
elif eqnstr == "&":
return a & b
elif eqnstr == "|":
return a | b
elif eqnstr == "%":
return a % b
return None
@staticmethod
def latex_len(s):
s = s.replace("\\left", "")
s = s.replace("\\right", "")
s = s.replace("\\hat", "")
s = s.replace("\\bar", "")
s = s.replace("\\tilde", "")
s = s.replace("\\displaystyle", "")
s = s.replace("\\!", "")
r = 0.0
i = 0
while i < len(s):
c = s[i]
if c == "\\":
r += 1
i += 1
while i < len(s) and s[i].isalpha():
i += 1
i -= 1
elif c in {"_", "^", "{", "}"}:
pass
elif c in {" ", ",", "(", ")"}:
r += 0.5
else:
r += 1
i += 1
return r
@staticmethod
def latex_split_line(strs, line_len, slstr = ""):
    """Break LaTeX expression(s) into lines of roughly `line_len` width.

    Each string is cut before every occurrence of "+", "-", "<", ">", "=",
    "\\le", "\\ge" or "\\neq" (note that "\\le" also matches "\\left").
    Consecutive pieces are then merged greedily while their combined
    `iutil.latex_len` stays within `line_len`; `None` means no limit.
    A single resulting line is returned as-is. Several lines are wrapped
    in an "array" environment, with `slstr` inserted before every line
    after the first. No lines at all gives "".
    """
    if line_len is None:
        line_len = 100000000000
    if isinstance(strs, str):
        strs = [strs]
    breakers = ("+", "-", "<", ">", "\\le", "\\ge", "=", "\\neq")
    all_lines = []
    for s in strs:
        pieces = []
        start = 0
        for i in range(len(s)):
            if s.startswith(breakers, i):
                pieces.append(s[start:i])
                start = i
        pieces.append(s[start:])
        merged = []
        for piece in pieces:
            if merged and iutil.latex_len(merged[-1]) + iutil.latex_len(piece) <= line_len:
                merged[-1] += piece
            else:
                merged.append(piece)
        all_lines.extend(merged)
    if not all_lines:
        return ""
    if len(all_lines) == 1:
        return all_lines[0]
    return "\\begin{array}{l}\n" + ("\\\\\n" + slstr).join(all_lines) + "\n\\end{array}"
@staticmethod
def strip_match(s, a, b):
if len(s) >= len(a) + len(b) and s.startswith(a) and s.endswith(b):
return s[len(a):-len(b)]
return None
@staticmethod
def latex_concat(style, strs):
    """Join strings line by line, merging embedded align* environments in LaTeX mode.

    Outside LaTeX style the strings are joined with newlines. In LaTeX style
    they are joined with LaTeX line breaks. If any string is a complete
    align* environment, those environments are unwrapped, every other
    entry gets an "& " prefix, and everything is rewrapped in a single
    align* environment.
    """
    if not style & PsiOpts.STR_STYLE_LATEX:
        return "\n".join(strs)
    sep = "\\\\\n"
    env_open = "\\begin{align*}\n"
    env_close = "\\end{align*}\n"
    items = list(strs)
    inners = [iutil.strip_match(item, env_open, env_close) for item in items]
    if all(inner is None for inner in inners):
        return sep.join(items)
    merged = [("& " + item) if inner is None else inner for item, inner in zip(items, inners)]
    return env_open + sep.join(merged) + env_close
@staticmethod
def str_list_concat(slists):
r = []
for slist in slists:
if isinstance(slist, str):
slist = [slist]
if r:
r += [",", " "]
r += slist
return r
@staticmethod
def meta_concat(metas):
    """Merge a collection of metadata dicts (None entries allowed) into one dict.

    Returns None if every entry is None, or if there are no entries.

    The dicts are processed in ascending order of "pf_note_priority"
    (None and a missing key count as 0; the sort is stable). For each key:
    - "pf_note" values are concatenated with `iutil.str_list_concat`;
    - "pf_note_priority" values are summed, then divided by the number of
      entries (None entries included);
    - for any other key, the first value seen in that order is kept.
    The result is copied with `iutil.copy`.
    """
    # Materialize first: `metas` may be a one-shot iterator, which the
    # all-None check would otherwise partially consume.
    metas = list(metas)
    if all(meta is None for meta in metas):
        return None
    metas.sort(key = lambda meta: 0 if meta is None else meta.get("pf_note_priority", 0))
    r = dict()
    for meta in metas:
        if meta is None:
            continue
        for key, value in meta.items():
            if key not in r:
                r[key] = value
            elif key == "pf_note":
                r[key] = iutil.str_list_concat([r[key], value])
            elif key == "pf_note_priority":
                r[key] = r[key] + value
    if "pf_note_priority" in r:
        r["pf_note_priority"] *= 1.0 / len(metas)
    return iutil.copy(r)
@staticmethod
def pf_note_str(s, style, note_color = None, add_space = 0, add_bracket = True):
    """Format a proof note such as " (since ...)", coloured in LaTeX style.

    `s` is a string or a list of tokens, rendered with `iutil.tostring_join`.
    `add_space` leading spaces are prepended. With `add_bracket`, the tokens
    are wrapped as "(since ...)". `note_color` defaults to the
    "proof_note_color" setting; in LaTeX style a non-None colour wraps the
    text in a \\color group.
    """
    if note_color is None:
        note_color = PsiOpts.settings["proof_note_color"]
    tokens = [s] if isinstance(s, str) else s
    lead = " " * add_space
    if add_bracket:
        tokens = [lead, "(", "since", " "] + tokens + [")"]
    else:
        tokens = [lead] + tokens
    body = iutil.tostring_join(tokens, style)
    if style & PsiOpts.STR_STYLE_LATEX and note_color is not None:
        return "{\\color{" + str(note_color) + "}{" + body + "}}"
    return body
@staticmethod
def str_python_multiline(s):
s = str(s)
return "(\"" + "\\n\"\n\"".join(s.split("\n")) + "\")"
@staticmethod
def istensor(a):
    """Return True for array-like values (anything with a `shape` attribute), excluding `Comp`."""
    if isinstance(a, Comp):
        return False
    return hasattr(a, "shape")
@staticmethod
def hash_short(s):
s = str(s)
return hash(s) % 99991
@staticmethod
def z3_vardict(rvs, auxs, auxis, reals):
    """Build a z3 encoding of the given variables.

    Let n = len(rvs) + len(auxis). The i-th variable of `rvs` becomes the
    constant bit-vector 1 << i of width n. Each variable of `auxs` and
    `auxis` becomes a free bit-vector of width n, named after the variable.
    Each variable of `reals` becomes a z3 Real.

    When n > 0, an uninterpreted function "H" from width-n bit-vectors to
    reals is declared and constrained with H(0) == 0, H(x) <= H(x|y) and
    H(x) + H(y) >= H(x|y) + H(x&y) for all x, y.

    Returns
    -------
    dict or None
        Maps each variable name (from `get_name()`) to its z3 term, plus the
        special keys "#n" (the width n), "#H" (the function, or None when
        n == 0) and "#solver" (a z3.Solver holding the constraints).
        Returns None if z3 is not available.
    """
    if z3 is None:
        return None
    n = len(rvs) + len(auxis)
    r = {}
    r["#n"] = n
    s = z3.Solver()
    if n:
        # Bound variables for the universally quantified constraints.
        x = z3.BitVec("TMPVARX", n)
        y = z3.BitVec("TMPVARY", n)
        H = z3.Function("H", x.sort(), z3.RealSort())
        s.add(H(z3.BitVecVal(0, n)) == 0)
        s.add(z3.ForAll([x, y], H(x) <= H(x|y)))
        s.add(z3.ForAll([x, y], H(x) + H(y) >= H(x|y) + H(x&y)))
        r["#H"] = H
    else:
        r["#H"] = None
    r["#solver"] = s
    for i, t in enumerate(rvs):
        r[t.get_name()] = z3.BitVecVal(1 << i, n)
    # NOTE(review): if n == 0 but `auxs` is non-empty, this creates
    # zero-width bit-vectors; confirm that callers never do this.
    for t in auxs + auxis:
        r[t.get_name()] = z3.BitVec(t.get_name(), n)
    for t in reals:
        r[t.get_name()] = z3.Real(t.get_name())
    return r
@staticmethod
def get_solver(psolver = None):
    """Choose the name of the solver backend to use.

    Candidates are tried in this order:
    1. `psolver`, if given;
    2. the non-empty entries of the "solver" setting (a string or a list of strings);
    3. the defaults in `iutil.solver_list`.

    A candidate is accepted if it is "scipy", "z3", or starts with "pulp.",
    "pyomo." or "ortools.", and the corresponding module global (scipy, z3,
    pulp, pyo, ortools) is not None. Rejected candidates from steps 1-2 are
    collected and reported in a RuntimeWarning.

    Returns
    -------
    str
        The selected candidate, or "" if none is usable.
    """
    # (name, warn_if_unavailable) pairs; default entries never warn.
    csolver_list = [(x, False) for x in iutil.solver_list]
    setting_solver = PsiOpts.settings["solver"]
    if setting_solver is None:
        setting_solver = ""
    if isinstance(setting_solver, str):
        setting_solver = [setting_solver]
    csolver_list = [(x, True) for x in setting_solver if x != ""] + csolver_list
    if psolver is not None:
        csolver_list = [(psolver, True)] + csolver_list
    warn_list = []
    for s, iswarn in csolver_list:
        sel = False
        if s == "scipy" and (scipy is not None):
            sel = True
        elif s.startswith("pulp.") and (pulp is not None):
            sel = True
        elif s.startswith("pyomo.") and (pyo is not None):
            sel = True
        elif s.startswith("ortools.") and (ortools is not None):
            sel = True
        elif s == "z3" and (z3 is not None):
            sel = True
        if sel:
            if warn_list:
                warnings.warn("Solver " + ", ".join(warn_list) + " not found. Falling back to " + s +
                    ". Use PsiOpts.setting(solver=\"" + s + "\") to stop this warning.", RuntimeWarning)
            return s
        else:
            if iswarn:
                warn_list.append(s)
    if warn_list:
        warnings.warn("Solver " + ", ".join(warn_list) + " not found.", RuntimeWarning)
    return ""
@staticmethod
def pulp_get_solver(solver):
    """Return the PuLP solver object for a name of the form "pulp.<NAME>".

    <NAME> (the text after the first ".") is upper-cased. "OTHER" returns
    `iutil.pulp_solver`. Otherwise a previously cached object from
    `iutil.pulp_solvers` is returned if present. If not, one of pulp.GLPK,
    pulp.PULP_CBC_CMD (for "CBC" or "PULP_CBC_CMD"), pulp.GUROBI,
    pulp.CPLEX, pulp.MOSEK or pulp.CHOCO_CMD is created with
    msg = the "verbose_solver" setting, timeLimit = PsiOpts.timer_left_sec()
    and options = the "pulp_options" setting. The object is cached and
    returned. Unknown names yield None.
    """
    coptions = PsiOpts.settings["pulp_options"]
    msg = PsiOpts.settings["verbose_solver"]
    copt = solver[solver.index(".") + 1 :].upper()
    if copt == "OTHER":
        return iutil.pulp_solver
    # NOTE(review): a cached solver keeps the timeLimit computed when it was
    # first created; later calls do not refresh it. Confirm this is intended.
    if copt in iutil.pulp_solvers:
        return iutil.pulp_solvers[copt]
    r = None
    if copt == "GLPK":
        #r = pulp.solvers.GLPK(msg = 0, options = coptions)
        r = pulp.GLPK(msg = msg, timeLimit = PsiOpts.timer_left_sec(), options = coptions)
    elif copt == "CBC" or copt == "PULP_CBC_CMD":
        #r = pulp.solvers.PULP_CBC_CMD(options = coptions)
        r = pulp.PULP_CBC_CMD(msg = msg, timeLimit = PsiOpts.timer_left_sec(), options = coptions)
    elif copt == "GUROBI":
        r = pulp.GUROBI(msg = msg, timeLimit = PsiOpts.timer_left_sec(), options = coptions)
    elif copt == "CPLEX":
        r = pulp.CPLEX(msg = msg, timeLimit = PsiOpts.timer_left_sec(), options = coptions)
    elif copt == "MOSEK":
        r = pulp.MOSEK(msg = msg, timeLimit = PsiOpts.timer_left_sec(), options = coptions)
    elif copt == "CHOCO_CMD":
        r = pulp.CHOCO_CMD(msg = msg, timeLimit = PsiOpts.timer_left_sec(), options = coptions)
    # Unknown names are cached as None as well.
    iutil.pulp_solvers[copt] = r
    return r
@staticmethod
def istorch(x):
    """Return True if `x` is a torch.Tensor; always False when torch is unavailable."""
    if torch is None:
        return False
    return isinstance(x, torch.Tensor)
@staticmethod
def ensure_torch(x):
    """Return `x` as a torch tensor.

    Objects providing `get_x()` are unwrapped first. Tensors are returned
    as-is; anything else is converted with `torch.tensor(..., dtype=torch.float64)`.
    """
    value = x.get_x() if hasattr(x, "get_x") else x
    if not iutil.istorch(value):
        value = torch.tensor(value, dtype=torch.float64)
    return value
@staticmethod
def ensure_comp(x, strict = True, copy = False, ref = None):
    """Convert `x` to a `Comp`.

    - None -> None.
    - BayesNet / FcnRelation: replaced by `x.get_region()`, then converted further.
    - Comp -> itself (a copy if `copy`).
    - str -> a copy of `ref.find(x)` if `ref` is given and that returns a
      Comp; otherwise `Comp.rv(x)`.
    - int (including bool) -> `Comp.empty()`.
    - any other `IBaseObj` -> `x.allcomprv_noaux()`.
    - anything else is iterated, and the converted items are summed into a
      new Comp. Only `strict` is forwarded to the recursive calls, not `ref`.
    """
    if x is None:
        return None
    if isinstance(x, (BayesNet, FcnRelation)):
        x = x.get_region()
    if isinstance(x, Comp):
        return x.copy() if copy else x
    if isinstance(x, str):
        if ref is not None:
            r = ref.find(x)
            if isinstance(r, Comp):
                return r.copy()
        return Comp.rv(x)
    # bool is a subclass of int, so True/False also map to an empty Comp here.
    if isinstance(x, int):
        return Comp.empty()
    if isinstance(x, IBaseObj):
        return x.allcomprv_noaux()
    r = Comp.empty()
    for y in x:
        t = iutil.ensure_comp(y, strict=strict, copy=True)
        if t is not None:
            r += t
    return r
@staticmethod
def ensure_expr(x, strict = True, copy = False, ref = None):
    """Convert `x` to an `Expr`.

    - None -> None; Expr -> itself (a copy if `copy`); Term -> `Expr.fromterm(x)`.
    - BayesNet, FcnRelation, Region and Comp raise ValueError when `strict`.
      Otherwise a BayesNet/FcnRelation is replaced by `x.get_region()` and
      converted further, a Region becomes `x.expr()`, and a Comp becomes
      `Expr.H(x)`.
    - ConcReal, bool, int, float, fractions.Fraction and tensor-like values
      (see `iutil.istensor`) become `Expr.const(float(x))`.
    - str -> a copy of `ref.find(x)` if `ref` is given and that returns an
      Expr; otherwise `Expr.parse(x)`.
    - Any other value raises ValueError when `strict`. Otherwise it is
      iterated and the converted items are summed, starting from `Expr.zero()`.
    """
    if x is None:
        return None
    if isinstance(x, Expr):
        if copy:
            return x.copy()
        else:
            return x
    if isinstance(x, Term):
        return Expr.fromterm(x)
    if isinstance(x, BayesNet):
        if strict:
            raise ValueError("Cannot convert BayesNet to Expr.")
        x = x.get_region()
    if isinstance(x, FcnRelation):
        if strict:
            raise ValueError("Cannot convert FcnRelation to Expr.")
        x = x.get_region()
    if isinstance(x, Region):
        if strict:
            raise ValueError("Cannot convert Region to Expr.")
        return x.expr()
    if isinstance(x, Comp):
        if strict:
            raise ValueError("Cannot convert Comp to Expr.")
        return Expr.H(x)
    if isinstance(x, ConcReal):
        x = float(x)
    if isinstance(x, (bool, int, float)):
        return Expr.const(float(x))
    if isinstance(x, fractions.Fraction):
        return Expr.const(float(x))
    # NOTE(review): float() on a tensor only works for single-element tensors.
    if iutil.istensor(x):
        return Expr.const(float(x))
    if isinstance(x, str):
        if ref is not None:
            r = ref.find(x)
            if isinstance(r, Expr):
                return r.copy()
        return Expr.parse(x)
    if strict:
        raise ValueError("Cannot convert " + str(type(x)) + " to Expr.")
    r = Expr.zero()
    for y in x:
        t = iutil.ensure_expr(y, strict=strict, copy=True)
        if t is not None:
            r += t
    return r
@staticmethod
def ensure_region(x, strict = True, copy = False):
    """Convert `x` to a `Region`.

    - None -> None; Region -> itself (a copy if `copy`).
    - Comp and Expr raise ValueError.
    - bool/int/float: truthy -> `Region.universe()`, falsy -> `Region.empty()`.
    - BayesNet / FcnRelation / PartialOrder -> `x.get_region()`;
      ConcModel -> `x.get_region(vals = True)`;
      ProofObj -> conversion of `x.step_regions()`.
    - str -> `Region.parse(x)`.
    - A tuple whose items are all Comp -> `markov(*x)`.
    - A list or set whose items are all tuples or Comp -> the region of
      `BayesNet(list(x))`.
    - Anything else is iterated, and the converted items are intersected,
      starting from `Region.universe()`.
    `strict` is not used here other than being forwarded to recursive calls.
    """
    if x is None:
        return None
    if isinstance(x, Region):
        if copy:
            return x.copy()
        else:
            return x
    if isinstance(x, Comp):
        raise ValueError("Cannot convert Comp to Region.")
        # return Region.universe()
    if isinstance(x, Expr):
        raise ValueError("Cannot convert Expr to Region.")
        # return x == 0
    if isinstance(x, (bool, int, float)):
        if x:
            return Region.universe()
        else:
            return Region.empty()
    if isinstance(x, (BayesNet, FcnRelation, PartialOrder)):
        return x.get_region()
    if isinstance(x, ConcModel):
        return x.get_region(vals = True)
    if isinstance(x, ProofObj):
        return iutil.ensure_region(x.step_regions(), strict=strict, copy=copy)
    if isinstance(x, str):
        return Region.parse(x)
    # NOTE(review): an empty tuple also satisfies all() and becomes markov().
    if isinstance(x, tuple):
        if all(isinstance(y, Comp) for y in x):
            # return BayesNet([x]).get_region()
            return markov(*x)
    if isinstance(x, (list, set)):
        if all(isinstance(y, (tuple, Comp)) for y in x):
            return BayesNet(list(x)).get_region()
    r = Region.universe()
    for y in x:
        t = iutil.ensure_region(y, strict=strict, copy=True)
        if t is not None:
            r &= t
    return r
@staticmethod
def type_coerce(a):
    """Convert every item of `a` to the highest-priority type present among them.

    Priority is Region, then Expr, then Comp; conversion uses the matching
    `iutil.ensure_*` helper and returns a new list. If none of these types
    occurs, `a` is returned unchanged.
    """
    for cls, convert in ((Region, iutil.ensure_region), (Expr, iutil.ensure_expr), (Comp, iutil.ensure_comp)):
        if any(isinstance(x, cls) for x in a):
            return [convert(x) for x in a]
    return a
@staticmethod
def log(x):
    """Natural logarithm of `x`: torch.log for torch tensors, numpy.log otherwise."""
    fn = torch.log if iutil.istorch(x) else numpy.log
    return fn(x)
@staticmethod
def xlogxoy(x, y):
    """Compute x * log(x / y) with the natural logarithm.

    If either argument is a torch tensor, the "opt_eps_denom" setting is
    added to both numerator and denominator inside the log. For plain
    numbers, x <= the "eps" setting gives 0.0.
    """
    if not (iutil.istorch(x) or iutil.istorch(y)):
        if x <= PsiOpts.settings["eps"]:
            return 0.0
        return x * numpy.log(x / y)
    d = PsiOpts.settings["opt_eps_denom"]
    return x * iutil.log((x + d) / (y + d))
@staticmethod
def xlogxoy2(x, y):
    """Compute x * log(x / y) ** 2 with the natural logarithm.

    If either argument is a torch tensor, the "opt_eps_denom" setting is
    added to both numerator and denominator inside the log. For plain
    numbers, x <= the "eps" setting gives 0.0.
    """
    if not (iutil.istorch(x) or iutil.istorch(y)):
        if x <= PsiOpts.settings["eps"]:
            return 0.0
        return x * (numpy.log(x / y) ** 2)
    d = PsiOpts.settings["opt_eps_denom"]
    return x * (iutil.log((x + d) / (y + d)) ** 2)
@staticmethod
def sqrt(x):
    """Square root of `x`: torch.sqrt for torch tensors, numpy.sqrt otherwise."""
    if not iutil.istorch(x):
        return numpy.sqrt(x)
    return torch.sqrt(x)
@staticmethod
def product(x):
r = 1
for a in x:
r *= a
return r
@staticmethod
def bitcount(x):
r = 0
while x != 0:
x &= x - 1
r += 1
return r
@staticmethod
def strpad(*args):
r = ""
tgtlen = 0
for i in range(0, len(args), 2):
r += str(args[i])
if i + 1 < len(args):
tgtlen += int(args[i + 1])
while len(r) < tgtlen:
r += " "
return r
@staticmethod
def list_tostr(x, tuple_delim = ", ", list_delim = ", ", inden = 0):
    """Pretty-print nested lists and tuples as a string.

    A list that contains lists or tuples is printed one element per line:
    elements after the first are preceded by `list_delim` and a newline, and
    nested levels are indented two spaces deeper. Other lists are printed on
    one line, joined by `list_delim`. Tuples are printed inline, joined by
    `tuple_delim`. Any other value is rendered with `str`. The output starts
    with `inden` spaces.
    """
    r = " " * inden
    if isinstance(x, list):
        if len([a for a in x if isinstance(a, list) or isinstance(a, tuple)]) > 0:
            r += "["
            for i in range(len(x)):
                if i == 0:
                    # The child is indented by inden + 2; keep only one of
                    # those spaces so the first element sits right after "[".
                    r += iutil.list_tostr(x[i], tuple_delim, list_delim, inden + 2)[inden + 1:]
                else:
                    r += list_delim + "\n" + iutil.list_tostr(x[i], tuple_delim, list_delim, inden + 2)
            r += " ]"
            return r
        else:
            r += "[" + list_delim.join([iutil.list_tostr(a, tuple_delim, list_delim, 0) for a in x]) + "]"
            return r
    elif isinstance(x, tuple):
        r += "(" + tuple_delim.join([iutil.list_tostr(a, tuple_delim, list_delim, 0) for a in x]) + ")"
        return r
    r += str(x)
    #r += x.tostring()
    return r
@staticmethod
def list_tostr_std(x):
    """Call `iutil.list_tostr` with ": " between tuple items and "; " between list items."""
    return iutil.list_tostr(x, list_delim = "; ", tuple_delim = ": ")
@staticmethod
def list_iscomplex(x):
if not isinstance(x, list):
return True
for a in x:
if not isinstance(a, tuple):
return True
if len(a) != 2:
return True
if isinstance(a[1], list):
return True
return False
@staticmethod
def str_inden(s, ninden, spacestr = " ", slstr = ""):
return slstr + spacestr * ninden + s.replace("\n", "\n" + slstr + spacestr * ninden)
@staticmethod
def enum_partition(n):
    """Enumerate all set partitions of {0, ..., n-1}.

    Each partition is a list of bit masks, one per block. The block that
    contains the lowest remaining element is always chosen first, so each
    partition is produced exactly once. For example, n = 2 gives
    [[3], [1, 2]] and n = 0 gives [[]].
    """
    def enum_partition_recur(mask):
        # All partitions of the elements whose bits are set in `mask`.
        if mask == 0:
            return [[]]
        r = []
        for i in range(n):
            if mask & (1 << i) != 0:
                # i is the lowest element in `mask`. Enumerate every subset
                # mask2 of the other elements to share i's block.
                mask2 = mask - (1 << i)
                while True:
                    mask3 = (1 << i) | mask2
                    r += [[mask3] + a
                        for a in enum_partition_recur(mask - mask3)]
                    if mask2 == 0:
                        break
                    # Next smaller subset of (mask without bit i).
                    mask2 = (mask2 - 1) & (mask - (1 << i))
                # Only the lowest set bit is expanded.
                break
        return r
    return enum_partition_recur((1 << n) - 1)
@staticmethod
def tsort(x):
"""Topological sort."""
n = len(x)
ninc = [0] * n
for i in range(n):
for j in range(n):
if x[i][j]:
ninc[j] += 1
cstack = [i for i in range(n) if ninc[i] == 0]
r = []
while len(cstack) > 0:
i = cstack.pop()
r.append(i)
for j in range(n):
if x[i][j]:
ninc[j] -= 1
if ninc[j] == 0:
cstack.append(j)
return r
@staticmethod
def iscyclic(x):
    """Return True if the directed graph with adjacency matrix `x` has a cycle."""
    return len(x) > len(iutil.tsort(x))
@staticmethod
def signal_type(x):
if isinstance(x, tuple) and len(x) > 0 and isinstance(x[0], str):
return x[0]
return ""
@staticmethod
def mhash(x):
    """Hash `x`, recursing into lists and tuples so that lists can be hashed too."""
    if not isinstance(x, (list, tuple)):
        return hash(x)
    return hash(tuple(iutil.mhash(y) for y in x))
@staticmethod
def list_unique(x):
    """Remove duplicates from `x`, keeping first occurrences in their original order.

    Items are compared through `iutil.mhash`, so lists are supported. Items
    whose hashes collide are treated as duplicates.
    """
    seen = set()
    unique = []
    for item in x:
        key = iutil.mhash(item)
        if key in seen:
            continue
        seen.add(key)
        unique.append(item)
    return unique
@staticmethod
def list_sorted_unique(x):
    """Sort `x` by `iutil.mhash`, then drop items equal (`==`) to their predecessor."""
    ordered = sorted(x, key = iutil.mhash)
    result = []
    for i, item in enumerate(ordered):
        if i == 0 or not item == ordered[i - 1]:
            result.append(item)
    return result
@staticmethod
def list_interleave(x):
return [a for t in itertools.zip_longest(*x) for a in t if a is not None]
@staticmethod
def sumlist(x):
    """Sum all leaves of a nested list/tuple structure; any other value is returned as-is."""
    if not isinstance(x, (list, tuple)):
        return x
    return sum(iutil.sumlist(a) for a in x)
@staticmethod
def split_number_text(s):
r = [""]
for c in s:
if c == " ":
if r[-1] != "":
r.append("")
else:
if c.isdigit():
if r[-1] != "" and not r[-1][-1].isdigit():
r.append("")
else:
if r[-1] != "" and r[-1][-1].isdigit():
r.append("")
r[-1] += c
return r
@staticmethod
def get_duration_ms(s):
    """Parse a human-readable duration such as "1h 30min" into milliseconds.

    `s` is tokenized with `iutil.split_number_text`. Each integer token is
    scaled by the unit token that follows it (case-insensitive):
    s/sec/secs/second/seconds, m/min/mins/minute/minutes,
    h/hr/hrs/hour/hours, d/day/days, w/week/weeks, mon/month/months,
    y/yr/yrs/year/years. A number with no unit, or an unknown one, counts
    as milliseconds.

    A month is 30 days and a year is 365 days. The previous code multiplied
    by an extra factor of 7, yielding 30 weeks and 365 weeks.
    """
    sec = 1000
    minute = 60 * sec
    hour = 60 * minute
    day = 24 * hour
    unit_ms = {}
    for names, scale in (
        (("s", "sec", "secs", "second", "seconds"), sec),
        (("m", "min", "mins", "minute", "minutes"), minute),
        (("h", "hr", "hrs", "hour", "hours"), hour),
        (("d", "day", "days"), day),
        (("w", "week", "weeks"), 7 * day),
        (("mon", "month", "months"), 30 * day),
        (("y", "yr", "yrs", "year", "years"), 365 * day),
    ):
        for name in names:
            unit_ms[name] = scale
    t = iutil.split_number_text(s)
    r = 0
    for i, tok in enumerate(t):
        if not tok.isdigit():
            continue
        unit = t[i + 1].lower() if i + 1 < len(t) else ""
        r += int(tok) * unit_ms.get(unit, 1)
    return r
@staticmethod
def remove_format(s):
return s.replace("{", "").replace("}", "").replace("_", "")
@staticmethod
def find_similarity(s, x):
    """Score how well name `x` matches name `s`; both may be "@@"-encoded.

    Exact equality scores 30000. Otherwise every name part (even "@@"
    index) of `s` is compared with every non-empty name part of `x`:
    - plain containment scores 10000 + len(x part) - len(s part);
    - containment after `iutil.remove_format` scores 5000 + len(reduced x
      part) - len(reduced s part).
    Returns the maximum score, or 0 if nothing matches.
    """
    if s == x:
        return 10000 * 3
    s_parts = s.split("@@")[::2]
    x_parts = [p for p in x.split("@@")[::2] if p != ""]
    best = 0
    for sp in s_parts:
        for xp in x_parts:
            if xp in sp:
                score = 10000 + len(xp) - len(sp)
            else:
                xr = iutil.remove_format(xp)
                sr = iutil.remove_format(sp)
                if xr not in sr:
                    continue
                score = 5000 + len(xr) - len(sr)
            best = max(best, score)
    return best
@staticmethod
def text_operation(s, ops, style = None, ensure_latex = True):
    """Apply text operations to a name, including all of its "@@" style variants.

    `s` may be "@@"-encoded as "name@@mask@@name_for_mask@@...". When
    `ensure_latex` is set and `s` has no LaTeX variant, one (a copy of the
    default name) is appended first. Each name part is then processed on
    its own.

    `ops` is a single tuple or a list of tuples:
    - ("set_sub", v) / ("set_super", v): the subscript/superscript becomes v;
    - ("append_sub", v) / ("append_super", v): ",v" is appended, or v is
      set if the script is empty or absent;
    - ("add_sub", k) / ("add_super", k): a numeric script is increased by
      k; otherwise ",k" is appended, or k is set if the script is empty or absent;
    - ("append", v): v is appended to the text.
    Sub/superscripts are edited through `iutil.latex_subsuperscript_convert`.
    `style` is only passed along in recursive calls; the operations ignore it.
    """
    t = s.split("@@")
    if ensure_latex and len(t) >= 1 and ("@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@") not in s:
        return iutil.text_operation(s + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + t[0], ops, style = style, ensure_latex = False)
    if len(t) >= 2:
        # Even indices are names, odd indices are style masks.
        for i in range(0, len(t), 2):
            t[i] = iutil.text_operation(t[i], ops, style = None if i == 0 else int(t[i - 1]), ensure_latex = False)
        return "@@".join(t)
    if isinstance(ops, tuple):
        ops = [ops]
    for op in ops:
        if isinstance(op, tuple):
            # The lambdas/functions below capture `op`, but they are called
            # immediately within this iteration, so late binding is harmless.
            if op[0] == "set_sub":
                s = iutil.latex_subsuperscript_convert(s, f_sub = lambda t: op[1])
            elif op[0] == "set_super":
                s = iutil.latex_subsuperscript_convert(s, f_super = lambda t: op[1])
            elif op[0] in ("append_sub", "append_super"):
                def f(t):
                    if len(t) == 0:
                        return op[1]
                    else:
                        return t + "," + op[1]
                if op[0] == "append_sub":
                    s = iutil.latex_subsuperscript_convert(s, f_sub = f)
                else:
                    s = iutil.latex_subsuperscript_convert(s, f_super = f)
            elif op[0] in ("add_sub", "add_super"):
                def f(t):
                    if len(t) == 0:
                        return str(op[1])
                    elif t.isdigit():
                        return str(int(t) + int(op[1]))
                    else:
                        return t + "," + str(op[1])
                if op[0] == "add_sub":
                    s = iutil.latex_subsuperscript_convert(s, f_sub = f)
                else:
                    s = iutil.latex_subsuperscript_convert(s, f_super = f)
            elif op[0] == "append":
                s += op[1]
    return s
@staticmethod
def set_suffix_num(s, k, schar, replace_mode = "set", style = None, ensure_latex = True):
    """Set or modify the numeric index of a name, including all of its "@@" variants.

    When `ensure_latex` is set and `s` has no LaTeX variant, one (a copy of
    the default name) is appended first. Each "@@" name part is then
    processed on its own. Depending on `replace_mode`:
    - "set": the subscript becomes str(k);
    - "append": str(k) is appended to the subscript (no separator);
    - "add": a numeric subscript is increased by k, otherwise str(k) is appended;
    - "suffix": `schar` + str(k) is appended to the whole name.
    Subscripts are edited through `iutil.latex_subsuperscript_convert`, which
    also creates a subscript if none exists. `style` is only passed along in
    recursive calls.
    """
    t = s.split("@@")
    if ensure_latex and len(t) >= 1 and ("@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@") not in s:
        return iutil.set_suffix_num(s + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + t[0], k, schar, replace_mode, style = style, ensure_latex = False)
    if len(t) >= 2:
        # Even indices are names, odd indices are style masks.
        for i in range(0, len(t), 2):
            t[i] = iutil.set_suffix_num(t[i], k, schar, replace_mode, style = None if i == 0 else int(t[i - 1]), ensure_latex = False)
        return "@@".join(t)
    if replace_mode != "suffix":
        def f_sub(v0):
            # Maps the current subscript text v0 ("" if absent) to the new one.
            # NOTE(review): for other modes v0 is returned unchanged, so a
            # name without a subscript gets an empty "_" appended.
            if replace_mode == "append":
                v0 += str(k)
            elif replace_mode == "set":
                v0 = str(k)
            elif replace_mode == "add":
                if v0.isdigit():
                    v0 = str(int(v0) + k)
                else:
                    v0 += str(k)
            return v0
        return iutil.latex_subsuperscript_convert(s, f_sub=f_sub)
        # s, v0 = iutil.break_subscript_latex(s)
        # if replace_mode == "append":
        #     v0 += str(k)
        # elif replace_mode == "set":
        #     v0 = str(k)
        # elif replace_mode == "add":
        #     if v0.isdigit():
        #         v0 = str(int(v0) + k)
        #     else:
        #         v0 += str(k)
        # if len(v0) > 1 and style == PsiOpts.STR_STYLE_LATEX:
        #     return s + "_{" + v0 + "}"
        # else:
        #     return s + "_" + v0
        # i = s.rfind(schar)
        # if i >= 0:
        #     if replace_mode == "append":
        #         return s + str(k)
        #     elif replace_mode == "set":
        #         return s[:i] + schar + str(k)
        #     else:
        #         if s[i + 1 :].isdigit():
        #             if replace_mode == "add":
        #                 return s[:i] + schar + str(int(s[i + 1 :]) + k)
        #             else:
        #                 return s[:i] + schar + str(k)
    return s + schar + str(k)
@staticmethod
def get_name_fromcat(s, style):
    """Pick the variant of an "@@"-encoded name that matches `style`.

    `s` has the form "default@@mask1@@name1@@mask2@@name2...". The first
    name whose mask shares a bit with `style` is chosen; if none does, the
    default is used. With STR_STYLE_REM_SYMBOLS, the characters "_", "{",
    "}", backslash and "^" are removed from the result.
    """
    parts = s.split("@@")
    chosen = parts[0]
    for mask, name in zip(parts[1:-1:2], parts[2::2]):
        if int(mask) & style:
            chosen = name
            break
    if style & PsiOpts.STR_STYLE_REM_SYMBOLS:
        chosen = chosen.translate(str.maketrans("", "", "_{}\\^"))
    return chosen
@staticmethod
def name_prefix_suffix(s, s_pre, s_suf):
t = s.split("@@")
for i in range(0, len(t), 2):
t[i] = s_pre + t[i] + s_suf
return "@@".join(t)
@staticmethod
def break_subscript_latex(s):
t = s.split("_")
if len(t) == 0:
return "", ""
if len(t) == 1:
return t[0], ""
v = "_".join(t[1:])
if v.startswith("{") and v.endswith("}"):
return t[0], v[1:-1]
return t[0], v
@staticmethod
def latex_break_last_group(s):
if len(s) == 0:
return "", ""
if s[-1] != "}":
return s[:-1], s[-1]
i = len(s)
lv = 0
while i > 0:
i -= 1
if s[i] == "{":
lv += 1
elif s[i] == "}":
lv -= 1
if lv == 0:
return s[:i], s[i+1:-1]
return "", s
@staticmethod
def latex_subsuperscript_convert(s, f_sub=None, f_super=None, f_body=None, insert_sub=True, insert_super=True):
    """Rewrite the trailing subscript and/or superscript of a LaTeX name.

    At most one trailing "_" group and one trailing "^" group, in either
    order, are peeled off the end of `s`. A group is either "{...}" or a
    single character (see `iutil.latex_break_last_group`).
    - `f_sub` / `f_super` map the existing script contents (without braces)
      to new contents. Returning None removes that script.
    - `f_body` maps the remaining base text.
    - If `insert_sub` / `insert_super` is set and the script is absent, the
      corresponding function is called with "" and a non-None result is appended.
    Contents that are not a single character or a single group are wrapped in braces.
    """
    r = ""  # rebuilt trailing scripts, in their original order
    n_sub = 0
    n_super = 0
    while len(s):
        t0, t1 = iutil.latex_break_last_group(s)
        if n_sub == 0 and t0.endswith("_"):
            if f_sub is not None:
                t1 = f_sub(t1)
            n_sub += 1
        elif n_super == 0 and t0.endswith("^"):
            if f_super is not None:
                t1 = f_super(t1)
            n_super += 1
        else:
            break
        if t1 is not None:
            if len(iutil.latex_break_last_group(t1)[0]):
                t1 = "{" + t1 + "}"
            # t0[-1] is the "_" or "^" marker.
            r = t0[-1] + t1 + r
        s = t0[:-1]
    if f_body is not None:
        s = f_body(s)
    s = s + r
    if insert_sub and n_sub == 0 and f_sub is not None:
        t1 = f_sub("")
        if t1 is not None:
            if len(iutil.latex_break_last_group(t1)[0]):
                t1 = "{" + t1 + "}"
            s += "_" + t1
    if insert_super and n_super == 0 and f_super is not None:
        t1 = f_super("")
        if t1 is not None:
            if len(iutil.latex_break_last_group(t1)[0]):
                t1 = "{" + t1 + "}"
            s += "^" + t1
    return s
@staticmethod
def is_single_term(x):
    """Return True if `x` needs no bracketing: an int, float, tuple, list, IVar, or a one-element Comp/Expr."""
    if isinstance(x, (int, float, tuple, list, IVar)):
        return True
    return isinstance(x, (Comp, Expr)) and len(x) == 1
@staticmethod
def fcn_name_maker(name, v, pname = None, lname = None, cropi = False, infix = False, latex_group = False, fcn_suffix = True, coeff_mul = False):
    """Build the "@@"-encoded display name of a function applied to arguments.

    One rendering is produced per style (standard, PSITIP, LaTeX), combined
    as "std@@<PSITIP mask>@@psitip@@<LaTeX mask>@@latex".

    Parameters
    ----------
    name : str or (str, str)
        Function name, or the operator when `infix`. `pname` / `lname`
        override it for the PSITIP / LaTeX styles. Any of the three may be a
        (for non-negative power, for negative power) tuple.
    v : object or list/tuple
        The argument(s). An argument may be a (value, power) tuple. An
        argument equal to "#break" stops rendering once the name/operator
        for that position has been emitted.
    cropi : bool
        For `Term` arguments rendered as "I(...)" or "H(...)", keep only the
        inner text.
    infix : bool
        Render as "a1 op a2 op ..." instead of "name(a1,a2,...)". Arguments
        that are not single terms (see `iutil.is_single_term`) are parenthesized.
    latex_group : bool
        In LaTeX style, wrap each infix argument in braces.
    fcn_suffix : bool
        Append the "fcn_suffix" setting to the result.
    coeff_mul : bool, int or collection of style masks
        Where it applies, powers are rendered as leading coefficients
        instead of exponents.
    """
    if not isinstance(v, list) and not isinstance(v, tuple):
        v = [v]
    r = ""
    for style in [PsiOpts.STR_STYLE_STANDARD, PsiOpts.STR_STYLE_PSITIP, PsiOpts.STR_STYLE_LATEX]:
        # Whether powers are rendered as coefficients in this style.
        ccoeff_mul = False
        if isinstance(coeff_mul, bool):
            ccoeff_mul = coeff_mul
        elif isinstance(coeff_mul, int):
            ccoeff_mul = (style == coeff_mul)
        else:
            ccoeff_mul = (style in coeff_mul)
        if style != PsiOpts.STR_STYLE_STANDARD:
            # Start the "@@<mask>@@" variant for this style.
            r += "@@" + str(style) + "@@"
        for i, a in enumerate(v):
            apow = None
            if isinstance(a, tuple):
                # (argument, power) pair.
                apow = a[1]
                a = a[0]
            a_single_term = iutil.is_single_term(a)
            def sel(*args):
                # First non-None candidate; a tuple candidate is resolved by
                # the sign of the current argument's power.
                for ca in args:
                    if ca is None:
                        continue
                    if isinstance(ca, tuple):
                        if apow is None or apow >= 0:
                            return ca[0]
                        else:
                            return ca[1]
                    else:
                        return ca
                return None
            # The function name goes before the first argument; an infix
            # operator goes before every argument except the first.
            if (i == 0) ^ infix:
                if style & PsiOpts.STR_STYLE_STANDARD:
                    r += sel(name)
                elif style & PsiOpts.STR_STYLE_PSITIP:
                    r += sel(pname, name)
                elif style & PsiOpts.STR_STYLE_LATEX:
                    r += sel(lname, name)
                if isinstance(a, str) and a == "#break":
                    break
            if not infix:
                if i == 0:
                    r += "("
                if i:
                    r += ","
            else:
                if latex_group and style & PsiOpts.STR_STYLE_LATEX:
                    r += "{"
                if not a_single_term:
                    r += "("
            t = ""
            if isinstance(a, (IVar, Comp, Term, Expr, Region)):
                if (style & PsiOpts.STR_STYLE_STANDARD or style & PsiOpts.STR_STYLE_LATEX) and isinstance(a, Comp):
                    t = a.tostring(style = style, add_bracket = True)
                else:
                    t = a.tostring(style = style)
            elif isinstance(a, str):
                if style & PsiOpts.STR_STYLE_PSITIP:
                    t = repr(a)
                else:
                    t = str(a)
            else:
                t = str(a)
            if apow is not None:
                apow_o = apow
                apow = iutil.float_tostr(apow)
                if ccoeff_mul:
                    # Coefficient form: 1 is omitted; for -1 only a leading
                    # "-" on the first argument is written (presumably later
                    # signs come from the tuple operator chosen by sel).
                    if apow == "1":
                        pass
                    elif apow == "-1":
                        if i == 0:
                            t = "-" + t
                    else:
                        t = iutil.float_tostr(abs(apow_o)) + t
                        if i == 0 and apow_o < 0:
                            t = "-" + t
                else:
                    # Exponent form, in the syntax of the current style.
                    if style & PsiOpts.STR_STYLE_PSITIP:
                        t += "**" + apow
                    elif style & PsiOpts.STR_STYLE_LATEX:
                        t += "^"
                        if len(apow) != 1:
                            t += "{"
                        t += apow
                        if len(apow) != 1:
                            t += "}"
                    else:
                        t += "^" + apow
            if cropi and isinstance(a, Term) and len(t) >= 3 and (t.startswith("I(") or t.startswith("H(")):
                r += t[2:-1]
            else:
                r += t
            if not infix:
                if i == len(v) - 1:
                    r += ")"
            else:
                if not a_single_term:
                    r += ")"
                if latex_group and style & PsiOpts.STR_STYLE_LATEX:
                    r += "}"
    if fcn_suffix:
        r += PsiOpts.settings["fcn_suffix"]
    return r
@staticmethod
def add_subscript_latex(s, v):
    """Append `v` to the LaTeX subscript of `s` (comma-separated), creating the subscript if absent."""
    def extend_sub(old):
        return old + "," + str(v) if len(old) else str(v)
    return iutil.latex_subsuperscript_convert(s, f_sub=extend_sub)
@staticmethod
def copy(x):
    """Deep-copy containers and library objects; None, ints, floats and strs are returned as-is.

    Lists, tuples, dicts (keys and values) and sets are copied recursively.
    numpy arrays use numpy.copy, torch tensors are detached and cloned, and
    `IBaseObj` instances use their `.copy()`. Anything else is returned unchanged.
    """
    if x is None or isinstance(x, (int, float, str)):
        return x
    if isinstance(x, list):
        return [iutil.copy(item) for item in x]
    if isinstance(x, tuple):
        return tuple(iutil.copy(item) for item in x)
    if isinstance(x, dict):
        return {iutil.copy(key): iutil.copy(val) for key, val in x.items()}
    if isinstance(x, set):
        return {iutil.copy(item) for item in x}
    if isinstance(x, numpy.ndarray):
        return numpy.copy(x)
    if torch is not None and isinstance(x, torch.Tensor):
        return x.detach().clone()
    if isinstance(x, IBaseObj):
        return x.copy()
    return x
@staticmethod
def allcomprv(x):
    """Combine the `allcomprv()` results of every object inside `x` into one Comp.

    None and plain numbers contribute nothing. Dicts are traversed through
    their (key, value) items; lists, tuples and sets recursively. Any other
    object is asked for `x.allcomprv()`.
    """
    if x is None or isinstance(x, (int, float, bool)):
        return Comp.empty()
    items = list(x.items()) if isinstance(x, dict) else x
    if not isinstance(items, (list, tuple, set)):
        return items.allcomprv()
    total = Comp.empty()
    for item in items:
        total += iutil.allcomprv(item)
    return total
@staticmethod
def allcomprealvar_exprlist(x):
    """Merge the `allcomprealvar_exprlist()` results of every object inside `x`.

    None and plain numbers give an empty ExprArray. Dicts are traversed
    through their (key, value) items; lists, tuples and sets recursively.
    Results are merged with `iadd_noduplicate`.
    """
    if x is None or isinstance(x, (int, float, bool)):
        return ExprArray.make([])
    items = list(x.items()) if isinstance(x, dict) else x
    if not isinstance(items, (list, tuple, set)):
        return items.allcomprealvar_exprlist()
    acc = ExprArray.make([])
    for item in items:
        acc.iadd_noduplicate(iutil.allcomprealvar_exprlist(item))
    return acc
@staticmethod
def allcompreal_exprlist(x):
    """Merge the `allcompreal_exprlist()` results of every object inside `x`.

    None and plain numbers give an empty ExprArray. Dicts are traversed
    through their (key, value) items; lists, tuples and sets recursively.
    Results are merged with `iadd_noduplicate`.
    """
    if x is None or isinstance(x, (int, float, bool)):
        return ExprArray.make([])
    items = list(x.items()) if isinstance(x, dict) else x
    if not isinstance(items, (list, tuple, set)):
        return items.allcompreal_exprlist()
    acc = ExprArray.make([])
    for item in items:
        acc.iadd_noduplicate(iutil.allcompreal_exprlist(item))
    return acc
@staticmethod
def is_solvable_inteqn(A, b):
    """Check whether the linear system A x = b has an integer solution x.

    Parameters
    ----------
    A : list of rows
        An h-by-w matrix. Entries are rounded to the nearest integer.
    b : sequence
        A length-h right-hand side. Entries are rounded to the nearest integer.

    Returns
    -------
    bool
        Uses the Smith normal form computed by the `smithnormalform` package
        (diagonal J and left transform S). With c = S b, the system is
        solvable iff, for every row i, c[i] is divisible by J[i][i]. Where
        that diagonal entry is 0 (or i >= w), c[i] must itself be 0.
        Returns False if the package reports an invalid decomposition.
        An empty system (h == 0) is solvable.
    """
    h = len(A)
    if h == 0:
        return True
    w = len(A[0])
    if w == 0:
        return all(x == 0 for x in b)
    A2 = smithnormalform.matrix.Matrix(h, w, [smithnormalform.z.Z(int(round(x))) for Ar in A for x in Ar])
    prob = smithnormalform.snfproblem.SNFProblem(A2)
    prob.computeSNF()
    if not prob.isValid():
        return False
    b2 = smithnormalform.matrix.Matrix(h, 1, [smithnormalform.z.Z(int(round(x))) for x in b])
    c = prob.S * b2
    for i in range(h):
        # Diagonal entry of J; rows past the last column have an implicit 0.
        t = prob.J.get(i, i).a if i < w else 0
        ci = c.get(i, 0).a
        if t == 0:
            if ci != 0:
                return False
        elif ci % t != 0:
            return False
    return True
@staticmethod
def check_meta_subs_criteria(v0, v1, cobj = None):
    """Decide, according to the "meta_subs_criteria" setting, whether substitution v0 -> v1 qualifies.

    Modes:
    - False / "never": False.
    - True / "always": True.
    - "eqtype": True iff v0 and v1 are both Comp or both Expr.
    - "eqtype_nonempty": as "eqtype", and v1 must also be a non-empty Comp
      or a non-zero Expr.
    Any other mode gives False. `cobj` is unused.
    """
    mode = PsiOpts.settings["meta_subs_criteria"]
    if mode is False or mode == "never":
        return False
    if mode is True or mode == "always":
        return True
    same_type = ((isinstance(v0, Comp) and isinstance(v1, Comp))
        or (isinstance(v0, Expr) and isinstance(v1, Expr)))
    if not same_type:
        return False
    if mode == "eqtype":
        return True
    if mode == "eqtype_nonempty":
        return (isinstance(v1, Comp) and not v1.isempty()) or (isinstance(v1, Expr) and not v1.iszero())
    return False
@staticmethod
def substitute(x, v0, v1):
    """Call `substitute(v0, v1)` on every `IBaseObj` inside `x`; return values are discarded.

    `x` may be None (ignored), an `IBaseObj`, or a list/tuple/dict nesting
    them; only dict values are visited. Declared as a staticmethod, like
    the other helpers, so it also works when called through an instance.
    """
    if x is None:
        return
    if isinstance(x, (list, tuple)):
        for a in x:
            iutil.substitute(a, v0, v1)
    elif isinstance(x, dict):
        for b in x.values():
            iutil.substitute(b, v0, v1)
    elif isinstance(x, IBaseObj):
        x.substitute(v0, v1)
@staticmethod
def substitute_whole(x, v0, v1):
    """Call `substitute_whole(v0, v1)` on every `IBaseObj` inside `x`; return values are discarded.

    `x` may be None (ignored), an `IBaseObj`, or a list/tuple/dict nesting
    them; only dict values are visited. Declared as a staticmethod, like
    the other helpers, so it also works when called through an instance.
    """
    if x is None:
        return
    if isinstance(x, (list, tuple)):
        for a in x:
            iutil.substitute_whole(a, v0, v1)
    elif isinstance(x, dict):
        for b in x.values():
            iutil.substitute_whole(b, v0, v1)
    elif isinstance(x, IBaseObj):
        x.substitute_whole(v0, v1)
@staticmethod
def isnumeric(x):
return isinstance(x, (int, float, fractions.Fraction))
@staticmethod
def str_join(x, delim = ""):
    """Flatten nested lists/tuples into one string, joining the `str` of the leaves with `delim` at every level."""
    if not isinstance(x, (list, tuple)):
        return str(x)
    return delim.join(iutil.str_join(a, delim) for a in x)
@staticmethod
def tostring_join(x, style, delim = "", nlstr = "\n"):
    """Render a token, or a (nested) list/tuple of tokens, as one string in `style`.

    - Lists/tuples: each element is rendered recursively, with `delim` added
      whenever output is already non-empty. String tokens starting with
      "#CMD_" produce no text; they are queued as commands. Currently
      "#CMD_EXCLUDE_PREV" makes the next Comp token subtract the previous
      Comp token.
    - Objects with a `tostring` method: `x.tostring(style = style)`.
    - Strings in LaTeX style: spaces, "(" / ")" and the words "because",
      "since", "therefore", "exists", "for all" and "is typical" are mapped
      to LaTeX forms or settings. Other text is wrapped in a \\text{...} group.
    - A newline string token becomes `nlstr`. Everything else uses `str(x)`.
    """
    if isinstance(x, list) or isinstance(x, tuple):
        r = ""
        # Pending "#CMD_..." commands that affect the following tokens.
        cmds = []
        # Most recent Comp token (after any exclusion), for #CMD_EXCLUDE_PREV.
        prev_comp = Comp.empty()
        for a in x:
            # NOTE(review): the delimiter is also added before command
            # tokens, which themselves produce no text.
            if r != "":
                r += delim
            if isinstance(a, str) and a.startswith("#CMD_"):
                cmds.append(a)
            else:
                if isinstance(a, Comp):
                    if "#CMD_EXCLUDE_PREV" in cmds:
                        a = a - prev_comp
                        cmds.remove("#CMD_EXCLUDE_PREV")
                    prev_comp = a.copy()
                r += iutil.tostring_join(a, style, delim, nlstr)
        return r
        # return delim.join(iutil.tostring_join(a, style, delim, nlstr) for a in x)
    if hasattr(x, "tostring"):
        return x.tostring(style = style)
    if isinstance(x, str):
        if style & PsiOpts.STR_STYLE_LATEX:
            if x == " ":
                return "\\,"
            if all(a == " " for a in x):
                return "\\;" * len(x)
            if x == "(":
                return "\\left("
            if x == ")":
                return "\\right)"
            if x == "because" or x == "since":
                return PsiOpts.settings["latex_because"]
            if x == "therefore":
                return PsiOpts.settings["latex_therefore"]
            if x == "exists":
                return PsiOpts.settings["latex_exists"]
            if x == "for all":
                return PsiOpts.settings["latex_forall"]
            if x == "is typical":
                return PsiOpts.settings["latex_is_typical"]
        if x == "\n":
            return nlstr
        if style & PsiOpts.STR_STYLE_LATEX:
            return "\\text{" + str(x) + "}"
    return str(x)
@staticmethod
def bit_reverse(x, n):
r = 0
for i in range(n):
if x & (1 << i):
r += 1 << (n - 1 - i)
return r
@staticmethod
def bin_to_gray(x):
"""Binary to Gray code. From en.wikipedia.org/wiki/Gray_code
"""
return x ^ (x >> 1)
@staticmethod
def gray_to_bin(x):
"""Gray code to binary. From en.wikipedia.org/wiki/Gray_code
"""
m = x
while m:
m >>= 1
x ^= m
return x
@staticmethod
def gbsearch(fcn, num_iter = 30):
"""Binary search.
"""
lb = None
ub = None
for it in range(num_iter):
m = 0.0
if it == 0:
m = 0.0
elif ub is None:
m = max(lb * 2, 1.0)
elif lb is None:
m = min(ub * 2, -1.0)
else:
m = (lb + ub) * 0.5
if fcn(m):
ub = m
else:
lb = m
if ub is None:
return lb
if lb is None:
return ub
return (lb + ub) * 0.5
@staticmethod
def polygon_frompolar(poly, inf_value = 1e6):
    """
    Convert polygon from polar representation to positive and negative portions.

    Each vertex is (w, c1, c2, ...). Only the sign of w is used (|w| <= 1e-10
    counts as 0):

    - w > 0: the coordinates go to the positive portion.
    - w < 0: the negated coordinates go to the negative portion.
    - w == 0: the coordinates are a direction at infinity; they are scaled to
      max-norm ``inf_value`` and added to both portions (negated for the
      negative one).

    An edge whose endpoints have opposite signs also passes through infinity;
    its direction is taken as the sum of the two vertices' coordinates.

    Parameters
    ----------
    poly : list
        List of 3-tuples (or lists / 1-D arrays) of vertices.
    inf_value : float, optional
        The value used as infinity. The default is 1e6.

    Returns
    -------
    r : list
        List of two polygons (positive and negative portions, both are lists
        of coordinate lists).
    """
    r = [[], []]
    if len(poly) == 0:
        return r
    # list(...) so that tuple or numpy-array vertices are concatenated, not
    # rejected (tuple) or broadcast-added (ndarray).
    x = [[0 if abs(a[0]) <= 1e-10 else 1 if a[0] > 0 else -1] + list(a[1:]) for a in poly]
    x.append(x[0])  # close the polygon
    for i in range(len(x) - 1):
        if x[i][0] > 0:
            r[0].append(x[i][1:])
        elif x[i][0] < 0:
            r[1].append([-y for y in x[i][1:]])
        ray = None
        if x[i][0] == 0:
            ray = x[i]
        elif x[i][0] * x[i + 1][0] < 0:
            ray = [y0 + y1 for y0, y1 in zip(x[i], x[i + 1])]
        if ray is not None:
            ray = ray[1:]
            norm = max((abs(y) for y in ray), default = 0)
            if norm == 0:
                # Degenerate (zero) direction: no point at infinity to add.
                continue
            ray = [y / norm * inf_value for y in ray]
            r[0].append(ray)
            r[1].append([-y for y in ray])
    return r
@staticmethod
def is_in_conic_hull(pts, x):
    """Test whether vector x lies in the conic hull of the rows of pts.

    Solves the feasibility LP  x = sum_j v_j * pts[j, :],  v_j >= 0  with the
    OR-Tools GLOP solver. A vector with norm <= PsiOpts.settings["eps_lp"] is
    always inside. If it is not (numerically) zero and pts has no rows, it is
    outside.

    Returns True iff the solver reports OPTIMAL (i.e. the LP is feasible).
    """
    tol = PsiOpts.settings["eps_lp"]
    npts, dim = pts.shape
    if numpy.linalg.norm(x) <= tol:
        return True
    if npts == 0:
        return False
    solver = ortools.linear_solver.pywraplp.Solver.CreateSolver("GLOP")
    weights = [solver.NumVar(0, solver.infinity(), "v" + str(j)) for j in range(npts)]
    for col in range(dim):
        combo = sum(weights[j] * float(pts[j, col]) for j in range(npts))
        solver.Add(combo == float(x[col]))
    solver.Minimize(0)  # pure feasibility problem
    outcome = solver.Solve()
    return outcome == ortools.linear_solver.pywraplp.Solver.OPTIMAL
@staticmethod
def is_in_conic_hull_interior(pts, c, rank=None):
    """Check whether the point sum_i c[i] * pts[i, :] can be re-expressed as a
    nonnegative combination that puts positive weight on some row of pts that
    lies outside the linear span of the rows selected by c.

    The "selected" rows are those with c[i] > eps_lp (c may be a list of
    weights or booleans). Rows outside their span are detected through the
    null space of the selected rows. An LP (OR-Tools GLOP, weights in
    [0, lp_ubound]) then maximises the total weight on those outside rows.

    ``rank`` is accepted for interface compatibility; it is computed when
    None but not otherwise used.

    Returns True iff the LP is solved to optimality and some outside row gets
    weight > eps_lp.
    """
    if rank is None:
        rank = numpy.linalg.matrix_rank(pts)
    tol = PsiOpts.settings["eps_lp"]
    upper = PsiOpts.settings["lp_ubound"]
    npts, dim = pts.shape
    chosen = [row for row in range(npts) if c[row] > tol]
    basis = scipy.linalg.null_space(pts[chosen, :])
    outside = [numpy.linalg.norm(pts[row, :].dot(basis)) > tol for row in range(npts)]
    target = sum([pts[row, :] * c[row] for row in range(npts)], numpy.zeros(dim))
    solver = ortools.linear_solver.pywraplp.Solver.CreateSolver("GLOP")
    weights = [solver.NumVar(0, upper, "v" + str(row)) for row in range(npts)]
    for col in range(dim):
        combo = sum(weights[j] * float(pts[j, col]) for j in range(npts))
        solver.Add(combo == float(target[col]))
    solver.Maximize(sum(weights[row] for row in range(npts) if outside[row]))
    if solver.Solve() != ortools.linear_solver.pywraplp.Solver.OPTIMAL:
        return False
    return any(outside[row] and weights[row].solution_value() > tol for row in range(npts))
@staticmethod
def conic_hull_facets(pts, numyield=None, mode=None):
    """Generator over facet normals of the conic hull of the rows of pts.

    Steps:
    1. Drop rows that lie in the conic hull of the remaining rows.
    2. Shuffle the rows with PsiOpts.get_random().
    3. Enumerate subsets of rank - 1 rows by backtracking. For each subset,
       take a null-space vector w of those rows, oriented so that
       w . avg > eps_lp (avg = mean of the rows). Among such vectors, keep the
       one with the largest w . avg. If every unselected row r has
       w . r >= -eps_lp, yield w.

    Parameters
    ----------
    pts : array-like
        Generators of the cone, one per row.
    numyield : int or None
        Stop after yielding this many normals (None = no limit).
    mode : str or None
        If "two_connected", rows i >= 1 for which
        iutil.is_in_conic_hull_interior(pts, {0, i}, rank) holds are removed.
        Row 0 is then always selected and intermediate interior checks are
        skipped.

    Yields nothing when the rank of the remaining rows is <= 1, or once
    PsiOpts.is_timer_ended() is true.
    """
    ceps = PsiOpts.settings["eps_lp"]
    rnd = PsiOpts.get_random()
    pts = numpy.array(pts)
    # Remove redundant generators (iterate backwards so deletion keeps indices valid).
    for i in range(pts.shape[0] - 1, -1, -1):
        if iutil.is_in_conic_hull(numpy.delete(pts, i, axis=0), pts[i,:]):
            pts = numpy.delete(pts, i, axis=0)
    # with numpy.printoptions(precision = 1):
    #     print(repr(pts))
    # NOTE(review): if get_random() returns a stdlib random.Random, shuffle()
    # swaps numpy row *views* and can duplicate rows; a numpy RandomState /
    # Generator shuffles 2-D arrays correctly. Confirm which is returned.
    rnd.shuffle(pts)
    n, k = pts.shape
    rank = numpy.linalg.matrix_rank(pts)
    check_intermediate = True
    n0 = n
    # print(n, k, rank)
    if rank <= 1:
        return
    # Interior reference direction, used to orient candidate normals.
    avg = numpy.sum(pts, axis=0) / n
    if mode == "two_connected":
        pts2 = numpy.array(pts)
        for i in range(pts.shape[0] - 1, 0, -1):
            if iutil.is_in_conic_hull_interior(pts, [i2 == 0 or i2 == i for i2 in range(n)], rank):
                pts2 = numpy.delete(pts2, i, axis=0)
        pts = pts2
        n, k = pts.shape
        check_intermediate = False
    # print(n0, n, k, rank)
    sel = [False] * n
    cur_numyield = [0]  # list so the nested generator can mutate it
    def recur(i):
        # Backtracking over row i: first explore without row i, then with it.
        # print(i, [i2 for i2 in range(n) if sel[i2]])
        if numyield is not None and cur_numyield[0] >= numyield:
            return
        sumsel = sum(sel)
        # Prune: out of rows, not enough rows left to reach rank - 1, or too many.
        if i >= n or sumsel + n - i < rank - 1 or sumsel > rank - 1:
            return
        if PsiOpts.is_timer_ended():
            return
        yield from recur(i + 1)
        sel[i] = True
        sumsel += 1
        if sumsel == rank - 1:
            # Candidate normals are orthogonal to all selected rows.
            ns = scipy.linalg.null_space(numpy.vstack(tuple(pts[i2,:] for i2 in range(n) if sel[i2])))
            mw = None
            mws = ceps
            for j in range(ns.shape[1]):
                w = ns[:,j]
                ws = w.dot(avg)
                if ws < 0:
                    ws = -ws
                    w = -w
                if ws > mws:
                    mws = ws
                    mw = w
            if mw is not None:
                # Supporting hyperplane: all other generators on the nonnegative side.
                if all(mw.dot(pts[i2,:]) >= -ceps for i2 in range(n) if not sel[i2]):
                    yield mw
                    cur_numyield[0] += 1
        elif sumsel == 1 or (sumsel < rank - 1 and not check_intermediate) or not iutil.is_in_conic_hull_interior(pts, sel, rank):
            # Extend the partial selection only if it is still admissible.
            yield from recur(i + 1)
        sel[i] = False
    if mode == "two_connected":
        sel[0] = True
        yield from recur(1)
    else:
        yield from recur(0)
class MHashList:
    """List keyed by iutil.mhash in which a newly added element replaces every
    earlier element with the same hash.

    ``x`` holds the elements and ``ha`` their hashes, index-aligned. The newest
    element is always last.
    """
    def __init__(self, x = None, ha = None):
        self.x = [] if x is None else x
        self.ha = [] if ha is None else ha

    def add(self, y):
        """Remove all elements hashing like y (in place), then append y."""
        h = iutil.mhash(y)
        for pos in reversed(range(len(self.x))):
            if self.ha[pos] == h:
                del self.x[pos]
                del self.ha[pos]
        self.x.append(y)
        self.ha.append(h)

    def clear(self):
        """Empty both lists in place."""
        del self.x[:]
        del self.ha[:]

    def __len__(self):
        return len(self.x)

    def __getitem__(self, key):
        return self.x[key]

    def copy(self):
        """Shallow copy with independent element/hash lists."""
        return MHashList(list(self.x), list(self.ha))
class MHashSet:
    """Insertion-ordered set keyed by iutil.mhash.

    ``x`` holds the distinct elements in insertion order; ``s`` holds their
    hashes. An element whose hash is already in ``s`` is not added again.
    Equality and hashing depend only on the set of hashes.
    """
    def __init__(self, x = None, s = None):
        self.x = [] if x is None else x
        self.s = set() if s is None else s
    def add(self, y):
        """Add y unless an element with the same mhash is present.

        Returns True if y was added, False otherwise.
        """
        h = iutil.mhash(y)
        if h in self.s:
            return False
        self.x.append(y)
        self.s.add(h)
        return True
    def clear(self):
        self.x[:] = []
        self.s.clear()
    def __iadd__(self, other):
        for y in other:
            self.add(y)
        return self
    def __len__(self):
        return len(self.x)
    def __getitem__(self, key):
        return self.x[key]
    def copy(self):
        return MHashSet(self.x[:], self.s.copy())
    def __eq__(self, other):
        # Return NotImplemented for foreign types so that e.g. `mset == None`
        # evaluates to False instead of raising AttributeError.
        if not isinstance(other, MHashSet):
            return NotImplemented
        return self.s == other.s
    def __ne__(self, other):
        if not isinstance(other, MHashSet):
            return NotImplemented
        return self.s != other.s
    def __hash__(self):
        return hash(frozenset(self.s))
def fcn_substitute(fcn):
    """Decorator for substitution methods ``fcn(cself, key, val)``.

    The wrapped method accepts flexible substitution specifications:

    - dicts, CompArrays (via ``to_dict()``), and lists of (key, val) pairs;
    - alternating positional ``key, val`` arguments;
    - keyword arguments.

    String keys and values are resolved with ``cself.find``. An empty-string
    value means "substitute by nothing" (Comp.empty() / Expr.zero()).

    All but the last pair are routed through temporary variables named
    "#TMPSUB<i>". Phase 0 maps each key to its temporary (the last pair maps
    key -> val directly). Phase 1 maps each temporary to its value.
    Presumably this makes the pairs act simultaneously (e.g. swapping X and Y).

    The wrapper returns None.
    """
    @functools.wraps(fcn)
    def wrapper(cself, *args, **kwargs):
        clist = []  # collected [key, val] (later [key, tmp, val]) entries
        def fcn2(cself, key, val):
            # Normalise one (key, val) pair and record it; invalid pairs are skipped.
            if isinstance(key, str):
                key = cself.find(key)
                if not ((isinstance(key, Comp) and not key.isempty())
                        or (isinstance(key, Expr) and not key.iszero())):
                    return
            if isinstance(val, str):
                if val == "":
                    val = None
                else:
                    val = cself.find(val)
                    if not ((isinstance(val, Comp) and not val.isempty())
                            or (isinstance(val, Expr) and not val.iszero())):
                        return
            if val is None:
                # Substitute by "nothing" of the key's kind.
                if isinstance(key, Comp):
                    val = Comp.empty()
                elif isinstance(key, Expr):
                    val = Expr.zero()
                else:
                    return
            if isinstance(key, Expr) and not isinstance(val, Comp):
                # Coerce numeric values etc. to Expr when possible.
                t = iutil.ensure_expr(val)
                if t is not None:
                    val = t
            # fcn(cself, key, val)
            clist.append([key, val])
        i = 0
        while i < len(args):
            if isinstance(args[i], dict):
                for key, val in args[i].items():
                    fcn2(cself, key, val)
                i += 1
            elif isinstance(args[i], CompArray):
                for key, val in args[i].to_dict().items():
                    fcn2(cself, key, val)
                i += 1
            elif isinstance(args[i], list):
                for key, val in args[i]:
                    fcn2(cself, key, val)
                i += 1
            elif i + 1 < len(args):
                # Positional pair: key, val
                fcn2(cself, args[i], args[i + 1])
                i += 2
            else:
                # Trailing unpaired positional argument is ignored.
                i += 1
        for key, val in kwargs.items():
            fcn2(cself, key, val)
        # print(clist)
        # Insert a temporary between key and value for every pair except the last.
        for i, v in enumerate(clist):
            if i < len(clist) - 1:
                if isinstance(v[0], Expr):
                    v.insert(1, Expr.real("#TMPSUB" + str(i)))
                else:
                    v.insert(1, Comp.rv("#TMPSUB" + str(i)))
        # print(clist)
        # Phase 0: v[0] -> v[1]; phase 1 (only 3-element entries): v[1] -> v[2].
        for it in range(2):
            for i, v in enumerate(clist):
                if it + 1 < len(v):
                    # print(str(v[it]) + " " + str(v[it + 1]))
                    fcn(cself, v[it], v[it + 1])
    return wrapper
def latex_postprocess(fcn):
    """Decorator for LaTeX-producing methods.

    When PsiOpts.settings["latex_color"] is not None, the returned string is
    wrapped as {\\color{<color>}{<latex>}}. Otherwise it is returned unchanged.
    """
    @functools.wraps(fcn)
    def wrapper(*args, **kwargs):
        text = fcn(*args, **kwargs)
        color = PsiOpts.settings["latex_color"]
        if color is None:
            return text
        return "".join(["{\\color{", str(color), "}{", text, "}}"])
    return wrapper
def fcn_list_to_list(fcn):
    """Decorator that broadcasts ``fcn`` over CompArray / ExprArray arguments.

    If no positional or keyword argument is a CompArray or ExprArray, ``fcn``
    is called unchanged. Otherwise ``fcn`` is called once per index
    i < maxlen, where maxlen is the length of the longest array argument.
    Each array argument contributes its i-th element; a shorter array
    contributes Comp.empty() / Expr.zero(). Non-array arguments are passed
    through as-is.

    The results are combined according to the type of the first result:

    - Region: combined with alland(...).
    - Comp: returned as a CompArray with the longest argument's shape.
    - Expr: returned as an ExprArray with the longest argument's shape.
    - Anything else: returned as a plain list.

    Returns None if there are no results.
    """
    @functools.wraps(fcn)
    def wrapper(*args, **kwargs):
        islist = False
        maxlen = -1
        maxshape = tuple()
        for a in itertools.chain(args, kwargs.values()):
            if CompArray.isthis(a) or ExprArray.isthis(a):
                islist = True
                # maxlen = max(maxlen, len(a))
                if len(a) > maxlen:
                    maxlen = len(a)
                    # Remember the shape of the longest array for the result.
                    if isinstance(a, list):
                        maxshape = (len(a),)
                    else:
                        maxshape = a.shape
        if not islist:
            return fcn(*args, **kwargs)
        r = []
        for i in range(maxlen):
            targs = []
            for a in args:
                if CompArray.isthis(a):
                    if i < len(a):
                        targs.append(a[i])
                    else:
                        targs.append(Comp.empty())
                elif ExprArray.isthis(a):
                    if i < len(a):
                        targs.append(a[i])
                    else:
                        targs.append(Expr.zero())
                else:
                    targs.append(a)
            tkwargs = dict()
            for key, a in kwargs.items():
                if CompArray.isthis(a):
                    if i < len(a):
                        tkwargs[key] = a[i]
                    else:
                        tkwargs[key] = Comp.empty()
                elif ExprArray.isthis(a):
                    if i < len(a):
                        tkwargs[key] = a[i]
                    else:
                        tkwargs[key] = Expr.zero()
                else:
                    tkwargs[key] = a
            r.append(fcn(*targs, **tkwargs))
        if len(r) > 0:
            if isinstance(r[0], Region):
                return alland(r)
            elif isinstance(r[0], Comp):
                return CompArray(r, maxshape)
            elif isinstance(r[0], Expr):
                return ExprArray(r, maxshape)
            else:
                return r
        else:
            return None
    return wrapper
class PsiRec:
    """Process-wide record/counters (class attributes only)."""
    # Counter of linear programming problems (starts at 0).
    num_lpprob = 0
class IVarType:
    """Kind of an elementary variable: none, random variable, or real variable."""
    NIL, RV, REAL = range(3)
class AlgType:
    """Algebraic structure attached to a variable (0 = none)."""
    NIL, SEMIGROUP, GROUP, ABELIAN, REAL = range(5)
class IBaseObj:
    """Base class of objects

    Provides metadata storage, LaTeX / IPython display helpers, substitution
    aliases, and variable-collection helpers. These rely on subclass methods
    such as ``copy``, ``substitute``, ``record_to`` and ``_latex_``.
    """
    def __init__(self):
        pass

    def add_meta(self, key, value):
        """Attach metadata ``key -> value`` and return self.

        IVar.__init__ and Comp.__init__ do not set ``self.meta``, so the
        attribute is looked up with a default rather than assumed to exist.
        """
        if getattr(self, "meta", None) is None:
            self.meta = {}
        self.meta[key] = value
        return self

    def get_meta(self, key):
        """Return the metadata value for key, or None if absent."""
        meta = getattr(self, "meta", None)
        if meta is None:
            return None
        return meta.get(key)

    def remove_meta(self, key):
        """Remove metadata key (if present) and return self."""
        meta = getattr(self, "meta", None)
        if meta is None:
            return self
        meta.pop(key, None)
        return self

    @latex_postprocess
    def _latex_(self):
        return ""

    def latex(self, skip_simplify = False):
        """LaTeX code

        If skip_simplify, render with PsiOpts(repr_simplify = False).
        """
        if skip_simplify:
            with PsiOpts(repr_simplify = False):
                r = self._latex_()
            return r
        return self._latex_()

    def display(self, skip_simplify = False, **kwargs):
        """IPython display

        Extra keyword arguments are applied as temporary PsiOpts settings.
        """
        if kwargs:
            with PsiOpts(**{key: val for key, val in kwargs.items()}):
                r = self.display(skip_simplify = skip_simplify)
            return r
        iutil.display_latex(self.latex(skip_simplify = skip_simplify))

    def print(self, **kwargs):
        """Print this object

        Extra keyword arguments are applied as temporary PsiOpts settings.
        """
        # `print` below resolves to the builtin (class scope is not enclosing).
        if kwargs:
            with PsiOpts(**{key: val for key, val in kwargs.items()}):
                print(self)
            return
        print(self)

    def graphviz_text(self, **kwargs):
        """Text label for graphviz output (LaTeX converted to text if enabled)."""
        if PsiOpts.settings["graphviz_text_convert"]:
            r = self.latex()
            r = iutil.latex_to_text(r, style="graphviz")
            return r
        return str(self)

    def display_bool(self, s = "{region} \\;\\mathrm{{is}}\\;\\mathrm{{{truth}}}", skip_simplify = True):
        """IPython display, show bool
        """
        iutil.display_latex(s.format(region = self.latex(skip_simplify = skip_simplify),
                                     truth = str(bool(self))))

    def display_truth_value(self, s = "{region} \\;\\mathrm{{is}}\\;\\mathrm{{{truth}}}", skip_simplify = True):
        """IPython display, show truth value
        """
        iutil.display_latex(s.format(region = self.latex(skip_simplify = skip_simplify),
                                     truth = str(self.truth_value())))

    def substituted(self, *args, **kwargs):
        """Substitute variable v0 by v1 (v1 can be compound), return result"""
        r = self.copy()
        r.substitute(*args, **kwargs)
        return r

    def substituted_whole(self, *args, **kwargs):
        """Substitute variable v0 by v1 (v1 can be compound), return result"""
        r = self.copy()
        r.substitute_whole(*args, **kwargs)
        return r

    def subs(self, *args, **kwargs):
        """Alias of substituted_whole
        """
        return self.substituted_whole(*args, **kwargs)

    def subs_aux(self, *args, **kwargs):
        """Alias of substituted_aux
        """
        return self.substituted_aux(*args, **kwargs)

    def substituted_swap(self, x, y = None):
        """Swap two variables x, y (or every pair in the list x of (a, b) pairs)
        """
        if y is not None:
            x = [(x, y)]
        return self.substituted_whole(list(x) + [(b, a) for a, b in x])

    def subs_swap(self, *args, **kwargs):
        """Alias of substituted_swap
        """
        return self.substituted_swap(*args, **kwargs)

    def reg(self, *args, **kwargs):
        """Alias of get_region
        """
        return self.get_region(*args, **kwargs)

    def ensure_region(self, *args, **kwargs):
        """Return self if it is a Region, otherwise self.get_region(...)."""
        if isinstance(self, Region):
            return self
        return self.get_region(*args, **kwargs)

    def defn(self, *args, **kwargs):
        """Alias of definition
        """
        return self.definition(*args, **kwargs)

    def sim(self, *args, **kwargs):
        """Alias of simplified
        """
        return self.simplified(*args, **kwargs)

    def __copy__(x):
        return x.copy()

    def __deepcopy__(x, memo):
        return x.copy()

    @staticmethod
    def set_repr_latex(enabled):
        """Enable/disable the Jupyter ``_repr_latex_`` hook on all objects."""
        hasa = hasattr(IBaseObj, "_repr_latex_")
        if enabled and not hasa:
            setattr(IBaseObj, "_repr_latex_", lambda s: "$" + s.latex() + "$")
        if not enabled and hasa:
            delattr(IBaseObj, "_repr_latex_")

    def simplified_truth(self, reg = None, quick = False):
        """Simplify, assuming the global PsiOpts "truth" region (if any).

        When a truth region is set, it is intersected into reg and
        simplification runs with truth temporarily cleared.
        """
        truth = PsiOpts.settings["truth"]
        if truth is not None:
            if reg is None:
                reg = truth.copy()
            else:
                reg = reg & truth
            with PsiOpts(truth = None):
                if isinstance(self, Expr):
                    return self.simplified(reg = reg, bnet = reg.get_bayesnet())
                else:
                    if quick:
                        return self.simplified_quick(reg = reg)
                    else:
                        return self.simplified(reg = reg)
        if quick:
            return self.simplified_quick(reg = reg)
        else:
            return self.simplified(reg = reg)

    def mark_rv(self, x, *args, **kwargs):
        """Replace each variable in x by a copy carrying the given markers."""
        y = x.copy()
        y.mark(*args, **kwargs)
        for a, b in zip(x, y):
            self.substitute(a, b)
        return self

    def marked_rv(self, x, *args, **kwargs):
        r = self.copy()
        r.mark_rv(x, *args, **kwargs)
        return r

    def allcomp(self):
        """All variables (random and real) appearing in this object."""
        index = IVarIndex()
        self.record_to(index)
        return index.comprv + index.compreal

    def allcomprv(self):
        """All random variables appearing in this object."""
        index = IVarIndex()
        self.record_to(index)
        return index.comprv

    def getauxall(self):
        return Comp.empty()

    def allcomprv_noaux(self):
        return self.allcomprv() - self.getauxall()

    def allcompreal(self):
        """All real variables appearing in this object."""
        index = IVarIndex()
        self.record_to(index)
        return index.compreal

    def allcompreal_exprlist(self):
        t = self.allcompreal()
        return ExprArray.make([Expr.real(v) for v in t])

    def allcomprealvar_exprlist(self):
        # NOTE(review): allcomprealvar is not defined on IBaseObj; presumably
        # provided by subclasses — confirm.
        t = self.allcomprealvar()
        return ExprArray.make([Expr.real(v) for v in t])

    def find(self, *args):
        return self.allcomp().find(*args)

    def simplify_regterm(self, reg = None):
        """Simplify each regterm (Term) and substitute the simplified form."""
        regterms = {}
        self.regtermmap(regterms, False)
        for (name, x) in regterms.items():
            if isinstance(x, Term):
                y = x.copy()
                t = y.simplify_regterm_expr(reg = reg)
                if t is not None:
                    self.substitute(Expr.fromterm(x), t)

    @property
    def rvs(self):
        return self.allcomprv_noaux()

    @property
    def reals(self):
        return self.allcomprealvar_exprlist()
PsiOpts.setting(repr_latex = True) # Latex display default on (Jupyter renders objects via LaTeX)
class IVar(IBaseObj):
    """Random variable or real variable
    Do NOT use this class directly. Use Comp instead

    Two IVars are equal iff their names are equal (see __eq__ / __hash__).
    Variables with a nonzero ``algtype`` carry an algebraic structure. Their
    value is a list of "words": lists of (generator IVar, exponent) pairs.
    """
    def __init__(self, vartype, name, reg = None, reg_det = False, markers = None, algtype = 0, alglist = None):
        self.vartype = vartype    # an IVarType constant
        self.name = name          # identity of the variable
        self.reg = reg
        self.reg_det = reg_det
        self.markers = markers    # list of (key, value) pairs, or None
        self.algtype = algtype    # AlgType constant; 0 = no algebraic structure
        self.alglist = alglist    # list of words; None means [[(self, 1)]]

    @staticmethod
    def rv(name):
        return IVar(IVarType.RV, name)

    @staticmethod
    def index(name):
        r = IVar(IVarType.RV, name)
        r.markers = [("index_shift", 0)]
        return r

    @staticmethod
    def real(name):
        return IVar(IVarType.REAL, name)

    @staticmethod
    def eps():
        return IVar(IVarType.REAL, "EPS")

    @staticmethod
    def one():
        return IVar(IVarType.REAL, "ONE")

    @staticmethod
    def inf():
        return IVar(IVarType.REAL, "INF")

    def isrealvar(self):
        """True for real variables other than the constants ONE, EPS, INF."""
        return self.vartype == IVarType.REAL and self.name != "ONE" and self.name != "EPS" and self.name != "INF"

    def get_marker_key(self, key):
        """Most recently added marker value for key, or None."""
        if self.markers is not None:
            for v, w in reversed(self.markers):
                if v == key:
                    return w
        return None

    def tostring(self, style = 0):
        r = ""
        if style & PsiOpts.STR_STYLE_LATEX:
            if self.name == "EPS":
                r = PsiOpts.settings["latex_eps"]
            elif self.name == "INF":
                r = PsiOpts.settings["latex_infty"]
        if r == "":
            r = iutil.get_name_fromcat(self.name, style)
        shift = self.get_marker_key("index_shift")
        if shift is not None and shift != 0:
            if shift > 0:
                r += "+"
            r += str(shift)
        if style & PsiOpts.STR_STYLE_LATEX and not style & PsiOpts.STR_STYLE_GRAPHVIZ:
            col = self.get_marker_key("color")
            if col is not None:
                r = "{\\color{" + str(col) + "}{" + r + "}}"
        if style & PsiOpts.STR_STYLE_LATEX and style & PsiOpts.STR_STYLE_GRAPHVIZ:
            r = iutil.latex_to_text(r, style="graphviz")
        return r

    def __str__(self):
        return self.tostring(PsiOpts.settings["str_style"])

    def __repr__(self):
        return self.tostring(PsiOpts.settings["str_style_repr"])

    @latex_postprocess
    def _latex_(self):
        return self.tostring(iutil.convert_str_style("latex"))

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        # Identity is by name (consistent with __hash__). For non-IVar operands
        # (e.g. a str reflected from `s in varlist`, or None) return
        # NotImplemented instead of raising AttributeError on `other.name`.
        if not isinstance(other, IVar):
            return NotImplemented
        return self.name == other.name

    def copy(self):
        return IVar(self.vartype, self.name, None if self.reg is None else self.reg.copy(),
                    self.reg_det, None if self.markers is None else self.markers[:],
                    self.algtype, iutil.copy(self.alglist))

    def copy_noreg(self):
        return IVar(self.vartype, self.name, None,
                    False, None if self.markers is None else self.markers[:],
                    self.algtype, iutil.copy(self.alglist))

    @staticmethod
    def word_normalize(algtype, x):
        """Sort word x in place by generator name (commutative types only)."""
        if algtype in (AlgType.ABELIAN, AlgType.REAL):
            x.sort(key = lambda a: str(a[0]))

    @staticmethod
    def word_combine_to(algtype, x, y):
        """Multiply word x by word y in place (x := x * y); drop zero exponents."""
        if algtype in (AlgType.ABELIAN, AlgType.REAL):
            # Commutative: add exponents generator-wise.
            for b, bc in y:
                for i in range(len(x)):
                    if x[i][0] == b:
                        x[i] = (x[i][0], x[i][1] + bc)
                        break
                else:
                    x.append((iutil.copy(b), bc))
        elif algtype in (AlgType.GROUP, AlgType.SEMIGROUP):
            # Non-commutative: cancel / merge only at the junction of x and y.
            y = iutil.copy(y)
            while x and y and x[-1][0] == y[0][0] and x[-1][1] + y[0][1] == 0:
                x.pop()
                y.pop(0)
            if x and y and x[-1][0] == y[0][0]:
                x[-1] = (x[-1][0], x[-1][1] + y[0][1])
                y.pop(0)
            x += y
        x[:] = [(a, ac) for a, ac in x if ac != 0]

    @staticmethod
    def word_combine(algtype, x, y):
        """Return the product word x * y (x is not modified)."""
        x = iutil.copy(x)
        IVar.word_combine_to(algtype, x, y)
        return x

    @staticmethod
    def word_pow(algtype, x, p):
        """Return word x raised to the power p."""
        if p == 0:
            return []
        x = iutil.copy(x)
        if p == 1:
            return x
        if algtype in (AlgType.ABELIAN, AlgType.REAL):
            return [(a, ac * p) for a, ac in x]
        if p < 0:
            # Inverse of a word: reverse the order and negate exponents.
            x = [(a, -ac) for a, ac in reversed(x)]
            p = -p
        r = list(x)
        for i in range(p - 1):
            IVar.word_combine_to(algtype, r, x)
        return r

    @staticmethod
    def word_name(algtype, x):
        """Display name of word x (identity element if x is empty)."""
        if len(x) == 0:
            return iutil.fcn_name_maker("id", ["#break"], lname = PsiOpts.settings["latex_group_id"])
        if PsiOpts.settings["latex_group_additive"]:
            return iutil.fcn_name_maker("*", [(a, ac) if ac != 1 else a for a, ac in x],
                                        lname = (PsiOpts.settings["latex_group_add"] + " ", PsiOpts.settings["latex_group_minus"] + " "),
                                        infix = True, coeff_mul = PsiOpts.STR_STYLE_LATEX)
        else:
            return iutil.fcn_name_maker("*", [(a, ac) if ac != 1 else a for a, ac in x],
                                        lname = PsiOpts.settings["latex_group_mul"] + " ", infix = True)

    def get_alglist(self):
        """Words representing this variable ([] if it has no algebraic type)."""
        if self.algtype == 0:
            return []
        if self.alglist is None:
            return [[(self, 1)]]
        return self.alglist

    def __mul__(self, other):
        if isinstance(other, (int, float, fractions.Fraction)) and other == 1:
            return self.copy()
        if self.algtype == 0:
            raise ValueError("Cannot multiply non-group-valued variables.")
        if self.algtype != other.algtype:
            raise ValueError("Cannot multiply variables with different types.")
        r = self.copy()
        r.alglist = []
        for x in self.get_alglist():
            for y in other.get_alglist():
                r.alglist.append(IVar.word_combine(self.algtype, x, y))
        if PsiOpts.settings["alg_normalize"]:
            for a in r.alglist:
                IVar.word_normalize(r.algtype, a)
        r.name = IVar.word_name(r.algtype, r.alglist[0])
        return r

    def __pow__(self, other):
        if self.algtype == 0:
            raise ValueError("Cannot take power of non-group-valued variables.")
        r = self.copy()
        r.alglist = iutil.copy(r.get_alglist())
        for x in r.alglist:
            x[:] = IVar.word_pow(r.algtype, x, other)
        if PsiOpts.settings["alg_normalize"]:
            for a in r.alglist:
                IVar.word_normalize(r.algtype, a)
        r.name = IVar.word_name(r.algtype, r.alglist[0])
        return r

    def __truediv__(self, other):
        return self * (other ** -1)

    def __rtruediv__(self, other):
        if isinstance(other, (int, float, fractions.Fraction)) and other == 1:
            return self ** -1
        return other * (self ** -1)

    def __rmul__(self, other):
        return self * other

    @staticmethod
    def word_recordto(algtype, index, x):
        """Record every generator of word x in index."""
        for a, ac in x:
            index.record(Comp([a]))

    @staticmethod
    def word_toarray(algtype, index, x):
        """Exponent vector of word x over index.comprv (numpy array)."""
        r = numpy.zeros(len(index.comprv))
        for a, ac in x:
            r[index.get_index(a)] += ac
        return r

    @staticmethod
    def word_tointlist(algtype, index, x):
        """Exponent vector of word x over index.comprv (list)."""
        r = [0] * len(index.comprv)
        for a, ac in x:
            r[index.get_index(a)] += ac
        return r

    @staticmethod
    def word_toid(index, x):
        """Word x with generators replaced by their index positions."""
        return [(index.get_index(a), ac) for a, ac in x]

    @staticmethod
    def word_hash(x):
        return hash(tuple(x))

    @staticmethod
    def word_complexity(x):
        return sum(abs(ac) + 1 for a, ac in x)

    @staticmethod
    def word_checkfcn(algtype, xs, y):
        """Whether word y can be expressed from the words xs.

        REAL: rank test (y in the row space of xs). GROUP / SEMIGROUP /
        ABELIAN: best-first search over products of xs (and their inverses
        where allowed), bounded by (len(xs) + 1) * alg_group_nit pushes.
        A False result may therefore only mean "not found within the bound".
        """
        index = IVarIndex()
        IVar.word_recordto(algtype, index, y)
        if len(index.comprv) == 0:
            return True
        if len(xs) == 0:
            return False
        for x in xs:
            IVar.word_recordto(algtype, index, x)
        if algtype == AlgType.REAL:
            m = numpy.vstack([IVar.word_toarray(algtype, index, x) for x in xs])
            m2 = numpy.vstack([m, IVar.word_toarray(algtype, index, y)])
            return numpy.linalg.matrix_rank(m2) <= numpy.linalg.matrix_rank(m)
        elif algtype in (AlgType.GROUP, AlgType.SEMIGROUP, AlgType.ABELIAN):
            xs = [IVar.word_toid(index, x) for x in xs]
            y = IVar.word_toid(index, y)
            wlist = []
            vis = set()
            nit = [0]
            maxnit = (len(xs) + 1) * PsiOpts.settings["alg_group_nit"]
            found = [False]
            # With inverses: start from y^-1 and search for the identity.
            # Without: start from the identity and search for y.
            start_inv = (algtype == AlgType.GROUP or algtype == AlgType.ABELIAN)
            push_inv = start_inv
            push_left = (algtype == AlgType.GROUP or algtype == AlgType.SEMIGROUP)
            def trypush(a):
                if found[0]:
                    return False
                if start_inv:
                    if len(a) == 0:
                        found[0] = True
                        return False
                else:
                    if a == y:
                        found[0] = True
                        return False
                if nit[0] > maxnit:
                    return False
                ahash = tuple(a)
                if ahash in vis:
                    return False
                heapq.heappush(wlist, (IVar.word_complexity(a), a))
                vis.add(ahash)
                nit[0] += 1
                return True
            if start_inv:
                trypush(IVar.word_pow(algtype, y, -1))
            else:
                trypush([])
            while wlist:
                if found[0]:
                    break
                acomp, a = heapq.heappop(wlist)
                for x in xs:
                    if found[0]:
                        break
                    trypush(IVar.word_combine(algtype, a, x))
                    if push_left:
                        trypush(IVar.word_combine(algtype, x, a))
                    if push_inv:
                        xinv = IVar.word_pow(algtype, x, -1)
                        trypush(IVar.word_combine(algtype, a, xinv))
                        if push_left:
                            trypush(IVar.word_combine(algtype, xinv, a))
            return found[0]
        return False

    @staticmethod
    def word_fcnrelation(algtype, xs):
        """FcnRelation recording which subsets of xs determine each xs[i]."""
        fcns = FcnRelation()
        index = IVarIndex()
        for x in xs:
            IVar.word_recordto(algtype, index, x)
        xmasks = []
        for x in xs:
            index2 = IVarIndex()
            IVar.word_recordto(algtype, index2, x)
            xmasks.append(index.get_mask(index2.comprv))
        n = len(xs)
        for i in range(n):
            for jmask in igen.subset_mask((1 << n) - 1 - (1 << i)):
                if algtype == AlgType.REAL:
                    if iutil.bitcount(jmask) > len(index.comprv):
                        continue
                # Generators of xs[i] must all appear in the candidate subset.
                jmask2 = 0
                for k, a in enumerate(xmasks):
                    if jmask & (1 << k):
                        jmask2 |= a
                if jmask2 | xmasks[i] != jmask2:
                    continue
                if fcns.check_fcn(jmask, 1 << i):
                    continue
                if IVar.word_checkfcn(algtype, [a for k, a in enumerate(xs) if (1 << k) & jmask], xs[i]):
                    fcns.add_fcn(jmask, 1 << i)
        fcns.simplify()
        return fcns
class Comp(IBaseObj):
"""Compound random variable or real variable
"""
def __init__(self, varlist):
self.varlist = varlist
@staticmethod
def empty():
"""
The empty random variable.
Returns
-------
Comp
"""
return Comp([])
@staticmethod
def rv(name):
"""
Random variable.
Parameters
----------
name : str
Name of the random variable.
Returns
-------
Comp
"""
return Comp([IVar(IVarType.RV, name)])
@staticmethod
def index(name):
"""
Random variable used as an index.
Parameters
----------
name : str
Name of the random variable.
Returns
-------
Comp
"""
return Comp([IVar.index(name)])
@staticmethod
def rv_reg(a, reg, reg_det = False):
r = a.copy_noreg()
for i in range(len(r.varlist)):
r.varlist[i].reg = reg.copy()
r.varlist[i].reg_det = reg_det
return r
#return Comp([IVar(IVarType.RV, str(a), reg.copy(), reg_det)])
@staticmethod
def real(name):
return Comp([IVar(IVarType.REAL, name)])
@staticmethod
def array(name, st, en = None):
if isinstance(st, int):
if en is None:
en = st
st = 0
st = range(st, en)
t = []
for i in st:
istr = str(i)
s = name + "_" + istr
s += "@@" + str(PsiOpts.STR_STYLE_LATEX)
s += "@@" + iutil.add_subscript_latex(name, istr)
t.append(IVar(IVarType.RV, s))
return Comp(t)
def get_name(self):
if len(self.varlist) == 0:
return ""
return self.varlist[0].name
def find(self, *args):
args = [x for b in args for x in iutil.split_comma(b)]
r = []
numrv = 0
numreal = 0
for carg0 in args:
carg0 = str(carg0)
cmaxa0 = None
for carg in carg0.split("+"):
cmax = 0
cmaxa = None
for a in self.varlist:
t = iutil.find_similarity(a.name, carg.strip())
if t > cmax:
cmax = t
cmaxa = Comp([a.copy()])
if cmaxa is not None:
if cmaxa0 is None:
cmaxa0 = cmaxa
else:
cmaxa0 += cmaxa
if cmaxa0 is None:
continue
if cmaxa0.get_type() == IVarType.RV:
r.append(cmaxa0)
numrv += 1
elif cmaxa0.get_type() == IVarType.REAL:
r.append(Expr.fromcomp(cmaxa0))
numreal += 1
if len(r) == 0:
return None
if len(r) == 1:
return r[0]
if numreal == 0 and numrv >= 2:
return CompArray(r)
if numrv == 0 and numreal >= 2:
return ExprArray(r)
return r
def get_type(self):
if len(self.varlist) == 0:
return IVarType.NIL
return self.varlist[0].vartype
def allcomp(self):
return self.copy()
def complexity(self):
return len(self.varlist)
def sorting_priority(self):
return self.complexity()
def sorting_tuple(self):
s = str(self)
return (s[0] if len(s) >= 1 else "", self.sorting_priority(), s)
def distance(self, other):
if len(self) == 0 and len(other) == 0:
return 0.0
t = len(self.inter(other))
return (len(self) + len(other) - t * 2) / (len(self) + len(other) - t)
def make_similar(self, other):
r = other.inter(self) + (self - other)
self.varlist = r.varlist
return self
def subsets(self, minsize = 0, maxsize = 100000, size = None, reverse = False):
"""Subsets of this random variable, as a generator
"""
return igen.subset(self, minsize = minsize, maxsize = maxsize, size = size, reverse = reverse)
@staticmethod
def parse(s):
"""Parse a string, e.g. "X+Y, Z W"
"""
r = RegionParser.parse_default("H(" + s + ")")
return iutil.ensure_comp(r)
def swapped_id(self, i, j):
if i >= len(self.varlist) or j >= len(self.varlist):
return self.copy()
r = self.copy()
r.varlist[i], r.varlist[j] = r.varlist[j], r.varlist[i]
return r
def set_algtype(self, algtype):
algtype = iutil.convert_algtype(algtype)
for a in self.varlist:
a.algtype = algtype
return self
def set_markers(self, markers):
for a in self.varlist:
if markers is None:
a.markers = None
else:
a.markers = markers[:]
return self
def add_markers(self, markers):
for a in self.varlist:
if a.markers is None:
a.markers = []
a.markers += markers
return self
def add_marker(self, key, value = 1):
self.add_markers([(key, value)])
return self
def add_marker_id(self, key):
return self.add_marker(key, iutil.get_count())
def mark(self, *args, **kwargs):
for a in args:
if a == "symm" or a == "disjoint" or a == "nonsubset":
self.add_marker_id(a)
else:
self.add_marker(a)
for key, value in kwargs.items():
self.add_marker(key, value)
return self
def get_marker_key(self, key):
for a in self.varlist:
if a.markers is not None:
for v, w in reversed(a.markers):
if v == key:
return w
return None
def write_marker_key(self, key, value):
for a in self.varlist:
if a.markers is None:
a.markers = []
found = False
for i in range(len(a.markers)):
if a.markers[i][0] == key:
a.markers[i] = (key, value)
found = True
if not found:
a.markers.append((key, value))
return None
def get_markers(self):
r = []
for a in self.varlist:
if a.markers is not None:
for v, w in a.markers:
if (v, w) not in r:
r.append((v, w))
return r
def set_card(self, m):
self.write_marker_key("card", m)
return self
def get_card(self):
r = 1
for a in self:
t = a.get_marker_key("card")
if t is None:
return None
r *= t
return r
def get_shape(self):
r = []
for a in self:
t = a.get_card()
if t is None:
raise ValueError("Cardinality of " + str(a) + " not set. Use " + str(a) + ".set_card(m) to set cardinality.")
return
r.append(t)
return tuple(r)
def get_shape_in(self):
return tuple()
def get_shape_out(self):
return self.get_shape()
@property
def card(self):
return self.get_card()
@card.setter
def card(self, value):
self.set_card(value)
@property
def shape(self):
return self.get_shape()
def get_index_shift(self):
return self.get_marker_key("index_shift")
def set_index_shift(self, x):
self.write_marker_key("index_shift", x)
def inc_index_shift(self, x):
c = self.get_index_shift()
if c is None:
c = 0
self.write_marker_key("index_shift", c + x)
def add_suffix(self, csuffix):
for a in self.varlist:
a.name += csuffix
def added_suffix(self, csuffix):
r = self.copy()
r.add_suffix(csuffix)
return r
def __getitem__(self, key):
r = self.varlist[key]
if isinstance(r, list):
return Comp(r)
return Comp([r])
def __setitem__(self, key, value):
if value.isempty():
del self.varlist[key]
self.varlist[key] = value.varlist[0]
def __delitem__(self, key):
del self.varlist[key]
def __iter__(self):
for a in self.varlist:
yield Comp([a])
def copy(self):
return Comp([a.copy() for a in self.varlist])
def copy_noreg(self):
return Comp([a.copy_noreg() for a in self.varlist])
def addvar(self, x):
if x in self.varlist:
return
self.varlist.append(x.copy())
def removevar(self, x):
self.varlist = [a for a in self.varlist if a != x]
def reg_excluded(self):
r = Comp.empty()
for a in self.varlist:
if a.reg is None:
r.varlist.append(a)
return r
def index_of(self, x):
for i, a in enumerate(self.varlist):
if x.ispresent(a):
return i
return -1
def ispresent_shallow(self, x):
if isinstance(x, str):
if x == "real":
return self.get_type() == IVarType.REAL
if x == "realvar":
for a in self.varlist:
if a.isrealvar():
return True
return False
if isinstance(x, Comp):
for y in x.varlist:
if y in self.varlist:
return True
return False
return x in self.varlist
def ispresent(self, x):
if isinstance(x, Expr) or isinstance(x, Region):
x = x.allcomp()
for a in self.varlist:
if a.reg is not None and a.reg.ispresent(x):
return True
return self.ispresent_shallow(x)
def __iadd__(self, other):
if isinstance(other, int):
c = self.get_index_shift()
if c is not None:
self.set_index_shift(c + other)
return self
for i in range(len(other.varlist)):
self.addvar(other.varlist[i])
return self
def __add__(self, other):
if isinstance(other, CompArray) or isinstance(other, ExprArray):
return NotImplemented
if isinstance(other, Expr):
return other.__radd__(self)
r = self.copy()
if isinstance(other, (Comp, int)):
r += other
return r
def __radd__(self, other):
r = self.copy()
if isinstance(other, (Comp, int)):
r += other
return r
# def attach(self, other, rv_add = None):
# t = self.copy()
# if rv_add is not None:
# t += rv_add
# return Region.universe().attach(other, rv_add = t)
def __isub__(self, other):
if isinstance(other, int):
self += -other
return self
for i in range(len(other.varlist)):
self.removevar(other.varlist[i])
return self
def __sub__(self, other):
r = self.copy()
r -= other
return r
def __floordiv__(self, other):
return BayesNet([self]) // other
def __xor__(self, other):
return BayesNet([self]) ^ other
def avoid(self, *args, samesuffix = True):
reg = Comp.empty()
for a in args:
if isinstance(a, IBaseObj):
reg += a.allcomp()
r = Region.universe().exists(self)
r.aux_avoid_from(reg, samesuffix = samesuffix)
return r.aux
def __imul__(self, other):
if iutil.isconstzero(other):
self.varlist = []
return self
if isinstance(other, Comp):
self.varlist = [a * b for a in self.varlist for b in other.varlist]
return self
def __mul__(self, other):
r = self.copy()
r *= other
return r
def __rmul__(self, other):
r = self.copy()
r *= other
return r
def __ipow__(self, other):
self.varlist = [a ** other for a in self.varlist]
return self
def __pow__(self, other):
r = self.copy()
r **= other
return r
def __truediv__(self, other):
return self * (other ** -1)
def __rtruediv__(self, other):
if isinstance(other, (int, float, fractions.Fraction)) and other == 1:
return self ** -1
return other * (self ** -1)
def product(self):
if len(self.varlist) == 0:
return 1
r = None
for a in self.varlist:
if r is None:
r = Comp([a]).copy()
else:
r *= Comp([a])
return r
def alg_region(self):
    """Return the constraints implied by the algebraic types of the variables.

    Variables with a nonzero ``algtype`` contribute the words returned by
    their ``get_alglist()``. For each algtype present, the words gathered so
    far are passed to ``IVar.word_fcnrelation``. Every entry ``t`` of the
    returned ``.fcn`` is a sequence of bit masks over the gathered words.
    ``H(t2[1] | t2[0])`` is added to a running sum ``r``, where the
    conditional entropy is of the variables selected by the second mask
    given those selected by the first.

    Returns ``Region.universe()`` if ``r`` is zero, otherwise ``r <= 0``.
    """
    ws = []   # words from get_alglist(), accumulated over algtypes
    was = []  # was[i] is a single-variable Comp for the variable that produced ws[i]
    algtypes = set(a.algtype for a in self.varlist)
    r = Expr.zero()
    for algtype in algtypes:
        if algtype == 0:
            # algtype 0 is skipped
            continue
        # NOTE(review): ws / was are not reset per algtype, so the relations
        # computed for a later algtype also see the words of earlier
        # algtypes. Confirm that this is intended.
        for a in self.varlist:
            if a.algtype == algtype:
                for x in a.get_alglist():
                    ws.append(x)
                    was.append(Comp([a]).copy())
        fcns = IVar.word_fcnrelation(algtype, ws)
        for t in fcns.fcn:
            t2 = []
            for x in t:
                # Decode bit mask x into the union of the selected Comps.
                ct = Comp.empty()
                for i, a in enumerate(was):
                    if x & (1 << i):
                        ct += a
                t2.append(ct)
            r += Expr.Hc(t2[1], t2[0])
    if r.iszero():
        return Region.universe()
    return r <= 0
def get_indreg(self, skip_abscont = False):
    """Return ``alg_region()``, or the universe when skip_abscont is true."""
    return Region.universe() if skip_abscont else self.alg_region()

def inter(self, other):
    """Intersection: variables of self (in self's order) also in other."""
    others = other.varlist
    return Comp([var for var in self.varlist if var in others])
def interleaved(self, other):
    """Return a new Comp alternating copies of the variables of self and other.

    Position i of self precedes position i of other. Once the shorter
    sequence is exhausted, the rest of the longer one follows.
    """
    merged = Comp([])
    mine, theirs = self.varlist, other.varlist
    for idx in range(max(len(mine), len(theirs))):
        for seq in (mine, theirs):
            if idx < len(seq):
                merged.varlist.append(seq[idx].copy())
    return merged
def size(self):
    """Number of variables."""
    return len(self.varlist)

def __len__(self):
    """Number of variables."""
    return len(self.varlist)

def __bool__(self):
    """True when there is at least one variable."""
    return len(self.varlist) > 0

def isempty(self):
    """Whether self is empty."""
    return not self.varlist

def from_mask(self, mask):
    """Return the subset selected by bit mask (bit i selects varlist[i])."""
    chosen = [var for idx, var in enumerate(self.varlist) if mask & (1 << idx)]
    return Comp(chosen)
def get_mask(self, x):
    """Bit mask of self relative to x: bit i is set iff varlist[i] is in x."""
    return sum(1 << idx for idx, var in enumerate(self.varlist) if var in x)

def super_of(self, other):
    """Whether self is a superset of other."""
    return all(var in self.varlist for var in other.varlist)

def super_of_index(self, other):
    """Smallest position in self of a variable of other, if self is a
    superset of other (``len(self)`` when other is empty); otherwise None."""
    best = len(self.varlist)
    for var in other.varlist:
        if var not in self.varlist:
            return None
        best = min(best, self.varlist.index(var))
    return best

def disjoint(self, other):
    """Whether self is disjoint from other."""
    return not any(var in self.varlist for var in other.varlist)

def __contains__(self, other):
    """Membership for an IVar, subset test for a Comp, False otherwise."""
    if isinstance(other, IVar):
        return other in self.varlist
    return self.super_of(other) if isinstance(other, Comp) else False
def __ge__(self, other):
    """Superset test: every variable of other is in self."""
    return self.super_of(other)

def __le__(self, other):
    """Subset test: every variable of self is in other."""
    return other.super_of(self)

def __eq__(self, other):
    """Equality of the sets of variable names (order/duplicates ignored).

    Operands without a ``varlist`` (None, numbers, Expr, ...) compare
    unequal instead of raising AttributeError. False is returned rather than
    NotImplemented on purpose: Python would otherwise try the reflected
    ``other.__eq__(self)``, and Expr equality builds a constraint object
    (see ``fcn_of``) instead of a bool.
    """
    other_vars = getattr(other, "varlist", None)
    if other_vars is None:
        return False
    return {a.name for a in self.varlist} == {a.name for a in other_vars}

def __ne__(self, other):
    """Negation of __eq__ (True for operands without a ``varlist``)."""
    other_vars = getattr(other, "varlist", None)
    if other_vars is None:
        return True
    return {a.name for a in self.varlist} != {a.name for a in other_vars}

def __gt__(self, other):
    """Strict superset test."""
    return self.super_of(other) and not other.super_of(self)

def __lt__(self, other):
    """Strict subset test."""
    return other.super_of(self) and not self.super_of(other)
def tolist(self):
    """Return a CompArray holding a copy of each element of self."""
    return CompArray([item.copy() for item in self])

def z3(self, vardict):
    """Z3 bit-vector for self: the bitwise OR of ``vardict[name]`` over the
    variables, or a zero bit-vector of width ``vardict["#n"]`` when empty.
    Returns None if z3 is unavailable."""
    if z3 is None:
        return None
    width = vardict["#n"]
    acc = None
    for var in self.varlist:
        term = vardict[var.name]
        acc = term if acc is None else acc | term
    return z3.BitVecVal(0, width) if acc is None else acc
def mean(self, f = None, name = None):
    """
    Returns the expectation of the function f.
    Parameters
    ----------
    f : function, numpy.array or torch.Tensor
        If f is a function, the number of arguments must match the number of
        random variables.
        If f is an array or tensor, shape must match the shape of the
        joint distribution.
    name : str, optional
        Name of the resulting real variable. Defaults to "mean_<k>", where k
        comes from ``iutil.get_count("mean")``.
    Returns
    -------
    r : Expr
        The expectation as an expression.
    """
    if name is None:
        name = "mean_" + str(iutil.get_count("mean"))

    def fcncall(xdist):
        # Evaluated later on the distribution of self.
        return xdist.mean(f)

    placeholder = Expr.real(name)
    term = Term(placeholder.terms[0][0].x, Comp.empty(), Region.universe(), 0,
                fcncall, [self.copy()])
    return Expr.fromterm(term)
def prob(self, *args):
    """
    Returns the probability mass function at *args.
    Parameters
    ----------
    *args : int
        The indices to query. E.g. (X+Y).prob(2,3) is P(X=2, Y=3).
    Returns
    -------
    r : Expr
        The probability as an expression.
    """
    args = tuple(args)
    # The variable name packs several renderings, each introduced by an
    # "@@<style>@@" marker: plain text first, then the PSITIP (repr) form,
    # then LaTeX. Presumably Expr.real selects one per output style — TODO
    # confirm. zip() pairs variables with indices, so a length mismatch only
    # affects the displayed names; fcncall always uses the full args tuple.
    name = "P(" + ",".join((a.tostring(style = PsiOpts.STR_STYLE_STANDARD) + "=" + str(b)) for a, b in zip(self, args)) + ")"
    name += "@@" + str(PsiOpts.STR_STYLE_PSITIP) + "@@"
    if len(self) == 1:
        name += repr(self)
    else:
        name += "(" + repr(self) + ")"
    name += ".prob(" + ",".join(str(b) for b in args) + ")"
    name += "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@"
    name += PsiOpts.settings["latex_prob"] + "(" + ",".join((a.tostring(style = PsiOpts.STR_STYLE_LATEX) + "=" + str(b)) for a, b in zip(self, args)) + ")"
    def fcncall(xdist):
        # Index the distribution of self at the queried point.
        return xdist[args]
    R = Expr.real(name)
    return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), Region.universe(), 0, fcncall, [self.copy()]))
def pmf(self, simplex = True):
    """Return the probability mass function as an ExprArray.

    The array shape is ``get_shape_in() + get_shape_out()``. With simplex
    true, the last output entry for each input index is ``1 - (sum of the
    other entries)`` instead of its own probability expression.
    """
    shape_in = self.get_shape_in()
    shape_out = self.get_shape_out()
    last = iutil.product(shape_out) - 1
    table = ExprArray.zeros(shape_in + shape_out)
    for xs in itertools.product(*[range(x) for x in shape_in]):
        running = Expr.zero()
        out_points = itertools.product(*[range(y) for y in shape_out])
        for pos, ys in enumerate(out_points):
            key = xs + ys
            if simplex and pos == last:
                table[key] = 1 - running
            else:
                table[key] = self.prob(*key)
            if simplex:
                running += table[key]
    return table
def sort(self):
    """Sort the variables in place by name."""
    self.varlist.sort(key = lambda var: var.name)

def tostring(self, style = 0, tosort = False, add_bracket = False):
    """Convert to string
    Parameters:
        style : Style of string conversion
            STR_STYLE_STANDARD : I(X,Y;Z|W)
            STR_STYLE_PSITIP : I(X+Y&Z|W)
        tosort : sort the variable names first
        add_bracket : wrap in parentheses when there are 2+ variables
    """
    style = iutil.convert_str_style(style)
    names = [var.tostring(style) for var in self.varlist]
    if not names:
        if style & PsiOpts.STR_STYLE_PSITIP:
            return "rv()"
        if style & PsiOpts.STR_STYLE_LATEX:
            return PsiOpts.settings["latex_rv_empty"]
        return "!"
    if tosort:
        names.sort()
    if style & PsiOpts.STR_STYLE_PSITIP:
        separator = "+"
    elif style & PsiOpts.STR_STYLE_LATEX:
        separator = PsiOpts.settings["latex_rv_delim"] + " "
    else:
        separator = ","
    body = separator.join(names)
    if add_bracket and len(names) > 1:
        return "(" + body + ")"
    return body
def name_operation(self, ops):
    """Return a copy whose variable names are mapped through
    ``iutil.text_operation(name, ops)``."""
    renamed = self.copy()
    for var in renamed.varlist:
        var.name = iutil.text_operation(var.name, ops)
    return renamed

def __str__(self):
    settings = PsiOpts.settings
    return self.tostring(settings["str_style"], tosort = settings["str_tosort"])

def __repr__(self):
    return self.tostring(PsiOpts.settings["str_style_repr"])

@latex_postprocess
def _latex_(self):
    return self.tostring(iutil.convert_str_style("latex"))

def __hash__(self):
    """Hash of the set of variable names (consistent with __eq__)."""
    return hash(frozenset(var.name for var in self.varlist))
def isregtermpresent(self):
    """Whether any variable carries a region (``reg`` is not None)."""
    return any(var.reg is not None for var in self.varlist)

def __and__(self, other):
    """``H(self) & H(other)``; arrays defer to their reflected operator."""
    if isinstance(other, (CompArray, ExprArray)):
        return NotImplemented
    return Term.H(self) & Term.H(other)

def __or__(self, other):
    """``H(self) | H(other)``; an int operand is treated as the empty Comp."""
    if isinstance(other, (CompArray, ExprArray)):
        return NotImplemented
    rhs = Comp.empty() if isinstance(other, int) else other
    return Term.H(self) | Term.H(rhs)
def rename_var(self, name0, name1):
    """Rename every variable named name0 to name1, then forward the rename
    to the regions attached to the variables.

    Returns self, matching ``rename_map`` so that the two can be chained
    the same way (previously returned None).
    """
    for a in self.varlist:
        if a.name == name0:
            a.name = name1
    for a in self.varlist:
        if a.reg is not None:
            a.reg.rename_var(name0, name1)
    return self

def rename_map(self, namemap):
    """Rename according to name map (old name -> new name), including inside
    attached regions. Returns self.
    """
    for a in self.varlist:
        a.name = namemap.get(a.name, a.name)
    for a in self.varlist:
        if a.reg is not None:
            a.reg.rename_map(namemap)
    return self
@fcn_substitute
def substitute(self, v0, v1):
    """Substitute variable v0 by v1 (v1 can be compound).

    If v0 has more than one variable, each is substituted by v1 in turn.
    Otherwise the first variable of self whose name matches v0 is replaced,
    in place, by copies of those variables of v1 whose names are not
    already used by the other variables of self. The substitution is then
    forwarded to the regions attached to the variables, whether or not v0
    was found.
    """
    if len(v0) > 1:
        for v0c in v0:
            self.substitute(v0c, v1)
        return
    v0s = v0.get_name()
    for i in range(len(self.varlist)):
        if self.varlist[i].name == v0s:
            # Names of the remaining variables, so v1 does not add duplicates.
            nameset = set()
            for j in range(len(self.varlist)):
                if j != i:
                    nameset.add(self.varlist[j].name)
            self.varlist = self.varlist[:i] + [t.copy() for t in v1.varlist if t.name not in nameset] + self.varlist[i+1:]
            break
    # r = Comp.empty()
    # for a in self.varlist:
    #     if v0.ispresent_shallow(a):
    #         r += v1
    #     else:
    #         r += Comp([a])
    # self.varlist = r.varlist
    for a in self.varlist:
        if a.reg is not None:
            a.reg.substitute(v0, v1)
@fcn_substitute
def substitute_whole(self, v0, v1):
    """Substitute variable v0 by v1 (v1 can be compound) as a whole.

    Acts only when every variable of v0 is in self. The variables of v0 are
    removed, and copies of the variables of v1 not already present are
    inserted in order at the smallest position v0 occupied (clamped to the
    new length).
    """
    pos = self.super_of_index(v0)
    if pos is None:
        return
    self -= v0
    pos = min(pos, len(self.varlist))
    for var in v1.varlist:
        if var in self.varlist:
            continue
        self.varlist.insert(pos, var.copy())
        pos += 1
@staticmethod
def substitute_list(a, vlist, suffix = "", isaux = False):
    """Substitute variables in vlist into a (in place).

    vlist is either a list, processed recursively in order, or a tuple
    ``(v0, w, ...)`` meaning "substitute v0 by w". If w is a list, its first
    element is used (Comp.empty() for an empty list). With a non-empty
    suffix, the variable named ``v0.get_name() + suffix`` is substituted
    instead of v0. With isaux, ``a.substitute_aux`` is called instead of
    ``a.substitute``. Shorter tuples and other types are ignored.
    """
    if isinstance(vlist, list):
        for item in vlist:
            Comp.substitute_list(a, item, suffix, isaux)
        return
    if not isinstance(vlist, tuple) or len(vlist) < 2:
        return
    target = vlist[0]
    replacement = vlist[1]
    if isinstance(replacement, list):
        replacement = replacement[0] if len(replacement) > 0 else Comp.empty()
    if suffix != "":
        target = Comp.rv(target.get_name() + suffix)
    apply = a.substitute_aux if isaux else a.substitute
    apply(target, replacement)
@staticmethod
def substitute_list_to_dict(vlist, multi = False):
    """Convert a substitution list (see substitute_list) into a dict.

    Keys are the substituted objects (``vlist[0]`` of each tuple).

    * multi=False: each key maps to one replacement (the first element when
      the replacement is a list, Comp.empty() for an empty list); later
      entries overwrite earlier ones.
    * multi=True: every replacement is kept, with lists flattened. A key
      seen once maps to its value. After a second value arrives it maps to a
      list of the distinct values.
    """
    def add_value(target, key, value):
        if isinstance(value, list):
            for item in value:
                add_value(target, key, item)
            return
        if key not in target:
            target[key] = value
            return
        existing = target[key]
        if not isinstance(existing, list):
            existing = [existing]
            target[key] = existing
        if value not in existing:
            existing.append(value)

    result = {}
    if isinstance(vlist, list):
        for item in vlist:
            sub = Comp.substitute_list_to_dict(item, multi = multi)
            if multi:
                for key, value in sub.items():
                    add_value(result, key, value)
            else:
                result.update(sub)
    elif isinstance(vlist, tuple) and len(vlist) >= 2:
        key = vlist[0]
        if multi:
            for value in vlist[1:]:
                add_value(result, key, value)
        else:
            value = vlist[1]
            if isinstance(value, list):
                value = value[0] if len(value) > 0 else Comp.empty()
            result[key] = value
    return result

@staticmethod
def substitute_dict_ismulti(d):
    """Whether any value of the substitution dict d is a list."""
    return any(isinstance(value, list) for value in d.values())
def record_to(self, index):
    """Record self into index, then record ``allcomprv_noaux()`` of every
    region attached to a variable."""
    index.record(self)
    for var in self.varlist:
        region = var.reg
        if region is None:
            continue
        index.record(region.allcomprv_noaux())

def fcn_of(self, b):
    """Constraint that self is a function of b: ``H(self|b) == 0``."""
    cond_entropy = Expr.Hc(self, b)
    return cond_entropy == 0

def table(self, *args, **kwargs):
    """Plot the information diagram as a Karnaugh map.
    """
    return universe().table(self, *args, **kwargs)

def venn(self, *args, **kwargs):
    """Plot the information diagram as a Venn diagram.
    Can handle up to 5 random variables (uses Branko Grunbaum's Venn diagram for n=5).
    """
    return universe().venn(self, *args, **kwargs)
def _as_comparray(self):
    """Wrap the elements of self in a CompArray (used by the time-series views)."""
    return CompArray(list(self))

def series(self, vdir):
    """Get past or future sequence.
    Parameters:
        vdir : Direction, 1: future non-strict, 2: future strict,
               -1: past non-strict, -2: past strict
    """
    return self._as_comparray().series(vdir)

def past_ns(self):
    return self._as_comparray().past_ns()

def past(self):
    return self._as_comparray().past()

def future_ns(self):
    return self._as_comparray().future_ns()

def future(self):
    return self._as_comparray().future()

@property
def PN(self):
    return self._as_comparray().PN

@property
def P(self):
    return self._as_comparray().P

@property
def FN(self):
    return self._as_comparray().FN

@property
def F(self):
    return self._as_comparray().F
class IVarIndex:
    """Store index of variables
    Do NOT use this class directly

    Random variables (``vartype == IVarType.RV``) and all other variables
    are indexed separately: ``dictrv`` / ``dictreal`` map a variable name to
    its position in ``comprv`` / ``compreal``. Names starting with
    ``prefavoid`` are never recorded.
    """
    def __init__(self):
        self.dictrv = {}            # name -> position in comprv
        self.comprv = Comp.empty()
        self.dictreal = {}          # name -> position in compreal
        self.compreal = Comp.empty()
        self.prefavoid = "@@"       # record() skips names with this prefix

    def copy(self):
        """Return a copy with its own dicts and Comps."""
        r = IVarIndex()
        r.dictrv = self.dictrv.copy()
        r.comprv = self.comprv.copy()
        r.dictreal = self.dictreal.copy()
        r.compreal = self.compreal.copy()
        r.prefavoid = self.prefavoid
        return r

    def record(self, x):
        """Append every variable of the Comp x that is not indexed yet."""
        for v in x.varlist:
            if v.name.startswith(self.prefavoid):
                continue
            if v.vartype == IVarType.RV:
                names, comp = self.dictrv, self.comprv
            else:
                names, comp = self.dictreal, self.compreal
            if v.name not in names:
                names[v.name] = comp.size()
                comp.varlist.append(v)

    def add_varindex(self, x):
        """Record all variables of another IVarIndex x."""
        self.record(x.comprv)
        self.record(x.compreal)

    def get_index(self, x):
        """Return the index of the IVar x (or of the first variable of the
        Comp x), or -1 if it is not recorded.

        An empty Comp has no variable to look up and gives -1 (it used to
        raise IndexError, which also broke ``empty_comp in index``).
        """
        if isinstance(x, Comp):
            if not x.varlist:
                return -1
            x = x.varlist[0]
        names = self.dictrv if x.vartype == IVarType.RV else self.dictreal
        return names.get(x.name, -1)

    def __contains__(self, other):
        return self.get_index(other) >= 0

    def ispresent(self, x):
        """Return whether any variable in x appears here"""
        return any(self.get_index(y) >= 0 for y in x)

    def get_index_name(self, name):
        """Index of the variable called name (RV table first), or -1."""
        if name in self.dictrv:
            return self.dictrv[name]
        return self.dictreal.get(name, -1)

    def get_obj_name(self, name):
        """Return the comprv entry for an RV name, ``Expr.fromcomp`` of the
        compreal entry for a real name, or None if name is unknown."""
        if name in self.dictrv:
            return self.comprv[self.dictrv[name]]
        if name in self.dictreal:
            return Expr.fromcomp(self.compreal[self.dictreal[name]])
        return None

    def allcomp(self):
        """All recorded variables: RVs followed by reals."""
        return self.comprv + self.compreal

    def find(self, *args):
        return self.allcomp().find(*args)

    def get_mask(self, x):
        """Bit mask of the Comp x over the RV indices.

        Returns 0 if x is not of RV type, and -1 if some variable of x is
        not recorded.
        """
        if x.get_type() != IVarType.RV:
            return 0
        r = 0
        for a in x.varlist:
            k = self.get_index(a)
            if k < 0:
                return -1
            r |= (1 << k)
        return r

    def from_mask(self, m):
        """Comp of the recorded RVs selected by bit mask m."""
        return self.comprv.from_mask(m)

    def num_rv(self):
        return self.comprv.size()

    def num_real(self):
        return self.compreal.size()

    def size(self):
        return self.comprv.size() + self.compreal.size()

    def name_avoid_old(self, name0):
        """Legacy variant: append rename_char until the name is unused."""
        name1 = name0
        while self.get_index_name(name1) >= 0:
            name1 += PsiOpts.settings["rename_char"]
        return name1

    def name_avoid(self, name0):
        """Return name0 if unused. Otherwise return the first unused
        ``iutil.set_suffix_num(name0, k, rename_char, replace_mode = "append")``
        for k = 1, 2, ..."""
        name1 = name0
        rename_char = PsiOpts.settings["rename_char"]
        k = 1
        while self.get_index_name(name1) >= 0:
            name1 = iutil.set_suffix_num(name0, k, rename_char, replace_mode = "append")
            k += 1
        return name1

    def calc_rename_map(self, a):
        """Map each name in a (Comp names or str(x)) to a fresh unused name.

        Each fresh name is recorded as an RV immediately, so later names in a
        also avoid it.
        """
        m = {}
        for x in a:
            xstr = x.get_name() if isinstance(x, Comp) else str(x)
            tname = self.name_avoid(xstr)
            self.record(Comp.rv(tname))
            m[xstr] = tname
        return m
# Module-level variable index stored on PsiOpts (starts out empty).
PsiOpts.global_index = IVarIndex()
class TermType:
    """Kinds of Term, as returned by Term.get_type()."""
    NIL = 0      # term with no x components
    IC = 1       # information quantity (H / I, possibly conditional)
    REAL = 2     # first x component is a real variable
    REGION = 3   # term carries a region (reg is not None)
class TermAllowType:
    """Bit flags, one per term form; DEFAULT sets all of them.

    The names suggest H, conditional H, I, conditional I and three-way I;
    presumably used to restrict which forms are allowed (confirm at call
    sites).
    """
    H = 1 << 0
    HC = 1 << 1
    I = 1 << 2
    IC = 1 << 3
    I3 = 1 << 4
    DEFAULT = H | HC | I | IC | I3
class Term(IBaseObj):
"""A term in an expression
Do NOT use this class directly. Use Expr instead
"""
def __init__(self, x, z = None, reg = None, sn = 0, fcncall = None, fcnargs = None, reg_outer = None, z_present = False, termtname = None):
    """Create a term.

    x: the arguments (a list of Comp for H/I terms), z: conditioning Comp
    (empty when None), reg / reg_outer: regions for REGION terms, sn: sign
    used by region terms, fcncall / fcnargs: deferred function evaluation,
    z_present, termtname: extra tags.
    """
    self.x = x
    self.z = Comp.empty() if z is None else z
    self.reg = reg
    self.sn = sn
    self.fcncall = fcncall
    self.fcnargs = fcnargs
    self.reg_outer = reg_outer
    self.z_present = z_present
    self.termtname = termtname

def copy(self):
    """Copy the term. x and z are copied element-wise, and reg, fcnargs and
    reg_outer go through iutil.copy. fcncall is shared."""
    return Term(
        [comp.copy() for comp in self.x],
        self.z.copy(),
        iutil.copy(self.reg),
        self.sn,
        self.fcncall,
        iutil.copy(self.fcnargs),
        iutil.copy(self.reg_outer),
        self.z_present,
        self.termtname,
    )

def copy_noreg(self):
    """Copy via copy_noreg() of each Comp, dropping region, sign and function
    data but keeping termtname."""
    clone = Term([comp.copy_noreg() for comp in self.x], self.z.copy_noreg(), None, 0)
    clone.termtname = self.termtname
    return clone
@staticmethod
def zero():
    """The zero term (no arguments)."""
    return Term([], Comp.empty())

def isempty(self):
    """Whether self is empty."""
    return len(self.x) == 0 or any(len(comp) == 0 for comp in self.x)

def setzero(self):
    """Reset every field so that the term is zero."""
    self.x = []
    self.z = Comp.empty()
    for attr in ("reg", "reg_outer"):
        setattr(self, attr, None)
    self.sn = 0
    for attr in ("fcncall", "fcnargs"):
        setattr(self, attr, None)
    self.z_present = False
    self.termtname = None

def iszero(self):
    """A non-region term is zero when it has no arguments or an empty one."""
    if self.get_type() == TermType.REGION:
        return False
    return len(self.x) == 0 or any(comp.isempty() for comp in self.x)
@staticmethod
def fromcomp(x):
    """Term with the single argument x (a copy) and empty z."""
    return Term([x.copy()], Comp.empty())

@staticmethod
def H(x):
    """Entropy term H(x)."""
    return Term([x.copy()], Comp.empty())

@staticmethod
def I(x, y):
    """Mutual information term I(x;y)."""
    return Term([x.copy(), y.copy()], Comp.empty())

@staticmethod
def Hc(x, z):
    """Conditional entropy term H(x|z)."""
    return Term([x.copy()], z.copy())

@staticmethod
def Ic(x, y, z):
    """Conditional mutual information term I(x;y|z)."""
    return Term([x.copy(), y.copy()], z.copy())

@staticmethod
def fcn(fcnname, fcncall, fcnargs):
    """REGION term standing for ``fcncall(*fcnargs)``, named fcnname.

    x is a one-element list holding the real variable, the same layout as
    every other constructor and as Comp.mean / Comp.prob (which pass
    ``R.terms[0][0].x``). Previously the bare Comp was stored, so
    ``self.x[0]`` and ``len(self.x)`` relied on Comp's own indexing rather
    than on the list layout the rest of Term assumes.
    """
    return Term([Comp.real(fcnname)], Comp.empty(), reg = Region.universe(),
                sn = 0, fcncall = fcncall, fcnargs = fcnargs)
@staticmethod
def eps():
    """Epsilon."""
    return Term._single_real(IVar.eps())

@staticmethod
def one():
    """One."""
    return Term._single_real(IVar.one())

@staticmethod
def inf():
    """Infinity."""
    return Term._single_real(IVar.inf())

@staticmethod
def _single_real(var):
    """Term whose only argument is ``Comp([var])``, with empty z."""
    return Term([Comp([var])], Comp.empty())

def allcomp(self):
    """Union of all arguments and z."""
    combined = Comp.empty()
    for comp in itertools.chain(self.x, [self.z]):
        combined += comp
    return combined

def allcomprv_shallow(self):
    """Random variables (IVarType.RV) among the arguments and z."""
    combined = Comp.empty()
    for comp in itertools.chain(self.x, [self.z]):
        combined += comp
    return Comp([var for var in combined.varlist if var.vartype == IVarType.RV])

def get_name(self):
    """Name of allcomp()."""
    return self.allcomp().get_name()
def get_maxent_comp(self):
    """For a term tagged "H0" whose only fcnarg is an entropy term H(X)
    (one argument, empty z), return X; otherwise None."""
    if self.termtname != "H0":
        return None
    if len(self.fcnargs) != 1:
        return None
    inner = self.fcnargs[0]
    if isinstance(inner, Term) and len(inner.x) == 1 and inner.z.isempty():
        return inner.x[0]
    return None
def size(self):
    """Total size of z plus all arguments (variables counted per argument)."""
    return self.z.size() + sum(comp.size() for comp in self.x)

def complexity(self):
    """Heuristic complexity score.

    The score is len(z) plus twice the argument sizes. It adds 2 for a
    single-argument term, 2 per argument for terms with 3+ arguments, and
    the region's complexity for REGION terms.
    """
    n_args = len(self.x)
    score = len(self.z) + 2 * sum(len(comp) for comp in self.x)
    if n_args == 1:
        score += 2
    elif n_args >= 3:
        score += 2 * n_args
    if self.get_type() == TermType.REGION and self.reg is not None:
        score += self.reg.complexity()
    return score

def sorting_priority(self):
    """Sorting key; same as complexity()."""
    return self.complexity()
def get_type(self):
    """Classify the term: REGION if reg is set, NIL without arguments,
    REAL if the first argument is real, IC otherwise."""
    if self.reg is not None:
        return TermType.REGION
    if len(self.x) == 0:
        return TermType.NIL
    return TermType.REAL if self.x[0].get_type() == IVarType.REAL else TermType.IC

def distance(self, other):
    """Dissimilarity with other; 1.0 for different or non REAL/IC types."""
    kind = self.get_type()
    if kind != other.get_type():
        return 1.0
    if kind == TermType.REAL:
        return self.x[0].distance(other.x[0])
    if kind == TermType.IC:
        n_args = len(self.x)
        args_dist = iutil.distance(self.x, other.x)
        return (args_dist * n_args + self.z.distance(other.z)) / (n_args + 1)
    return 1.0

def make_similar(self, other):
    """Adjust an IC term in place to resemble other. Returns self.

    z is made similar to other.z. When both terms have 2+ arguments, each
    argument is made similar to its nearest argument of other (first one on
    ties), and the arguments are reordered to follow other's order.
    """
    if self.get_type() != other.get_type():
        return self
    if self.get_type() != TermType.IC:
        return self
    self.z.make_similar(other.z)
    if len(self.x) > 1 and len(other.x) > 1:
        buckets = [[] for _ in range(len(other.x))]
        for comp in self.x:
            best_idx, best_dist = -1, None
            for idx, target in enumerate(other.x):
                dist = comp.distance(target)
                if best_dist is None or best_dist > dist:
                    best_idx, best_dist = idx, dist
            comp.make_similar(other.x[best_idx])
            buckets[best_idx].append(comp)
        self.x = [comp for bucket in buckets for comp in bucket]
    return self
def _real_head_name(self):
    """Name of the first variable of a REAL term, or None for other terms."""
    if self.get_type() != TermType.REAL:
        return None
    return self.x[0].varlist[0].name

def iseps(self):
    """Whether this is the EPS constant."""
    return self._real_head_name() == "EPS"

def isone(self):
    """Whether this is the ONE constant."""
    return self._real_head_name() == "ONE"

def isinf(self):
    """Whether this is the INF constant."""
    return self._real_head_name() == "INF"

def isrealvar(self):
    """Whether this is a REAL term other than EPS / ONE / INF."""
    if self.get_type() != TermType.REAL:
        return False
    return self.x[0].varlist[0].name not in ("EPS", "ONE", "INF")
def isnonneg(self):
    """Whether the term is known to be nonnegative.

    Non-IC terms: only ONE, EPS and INF. In quantum mode, H is nonnegative
    when unconditioned and I when unconditioned or with disjoint arguments.
    Otherwise a zero-argument term is nonnegative, and H / H| / I / I| defer
    to the settings "hge0" / "hcge0" / "ige0" / "icge0" (default False).
    """
    if self.get_type() != TermType.IC:
        return self.isone() or self.iseps() or self.isinf()
    n_args = len(self.x)
    settings = PsiOpts.settings
    if settings.get("quantum", False):
        if n_args == 1:
            return self.z.isempty()
        if n_args == 2:
            return self.z.isempty() or (self.x[0].inter(self.x[1])).isempty()
        return False
    if n_args == 0:
        return True
    if n_args == 1:
        return settings.get("hge0" if self.z.isempty() else "hcge0", False)
    if n_args == 2:
        return settings.get("ige0" if self.z.isempty() else "icge0", False)
    return False
def isic2(self):
    """Two-argument IC term whose arguments, minus z, are disjoint."""
    if self.get_type() != TermType.IC or len(self.x) != 2:
        return False
    return (self.x[0] - self.z).disjoint(self.x[1] - self.z)

def isic3(self):
    """IC term with exactly three arguments."""
    return self.get_type() == TermType.IC and len(self.x) == 3

def isicle2(self):
    """IC term with at most two arguments."""
    return self.get_type() == TermType.IC and len(self.x) <= 2

def ishc(self):
    """IC term with one argument (possibly conditional entropy)."""
    return self.get_type() == TermType.IC and len(self.x) == 1

def isihc2(self):
    """IC term with exactly two arguments."""
    return self.get_type() == TermType.IC and len(self.x) == 2

def ish(self):
    """Unconditional single-argument IC term."""
    return self.get_type() == TermType.IC and len(self.x) == 1 and self.z.isempty()

def restricted(self, a):
    """Copy of an IC term with every argument and z intersected with a;
    None for other term types."""
    if self.get_type() != TermType.IC:
        return None
    result = self.copy()
    result.x = [comp.inter(a) for comp in result.x]
    result.z = result.z.inter(a)
    return result
def record_to(self, index):
    """Record all variables used by this term into index.

    For REGION terms, the order is: the non-aux variables of reg and
    reg_outer, then the fcnargs that are IVar / Comp / Term / Expr / Region.
    The arguments and z are recorded last.
    """
    if self.get_type() == TermType.REGION:
        for region in (self.reg, self.reg_outer):
            if region is not None:
                index.record(region.allcomprv_noaux())
        if self.fcnargs is not None:
            for arg in self.fcnargs:
                if isinstance(arg, (IVar, Comp, Term, Expr, Region)):
                    arg.record_to(index)
    for comp in self.x:
        comp.record_to(index)
    self.z.record_to(index)

def definition(self):
    """Return the definition of this term.

    For a REGION term whose name contains the "fcn_suffix" setting, return
    the term with its variable renamed to drop the suffix; otherwise a copy.
    """
    if self.get_type() == TermType.REGION:
        cname = self.x[0].get_name()
        suffix = PsiOpts.settings["fcn_suffix"]
        if suffix in cname:
            return self.substituted(self.x[0], Comp.real(cname.replace(suffix, "")))
    return self.copy()
def z3(self, vardict):
    """Z3 expression for this term (None if z3 is unavailable).

    IC terms expand by inclusion-exclusion over the arguments using the
    entropy function ``vardict["#H"]``. ONE maps to RealVal(1), other REAL
    terms to ``vardict[name]``, and NIL terms give None. REGION terms
    raise ValueError.
    """
    if z3 is None:
        return None
    _width = vardict["#n"]  # read for parity with Comp.z3 (KeyError if missing)
    entropy = vardict["#H"]
    kind = self.get_type()
    if kind == TermType.IC:
        n_args = len(self.x)
        total = None
        for mask in range(1 << n_args):
            chosen = sum((self.x[i] for i in range(n_args) if mask & (1 << i)), self.z.copy())
            if chosen.isempty():
                continue
            ent = entropy(chosen.z3(vardict))
            negate = iutil.bitcount(mask) % 2 == 0
            if total is None:
                total = -ent if negate else ent
            else:
                total = total - ent if negate else total + ent
        return total
    if kind == TermType.REAL:
        if self.isone():
            return z3.RealVal(1)
        return vardict[self.x[0].get_name()]
    if kind == TermType.REGION:
        raise ValueError("Optimization quantities are not supported for Z3. Use another solver.")
    return None
def _card_tuple(self, variables):
    """Cardinalities of variables as a tuple.

    Raises ValueError naming the first variable whose cardinality is unset.
    """
    cards = []
    for var in variables:
        card = var.get_card()
        if card is None:
            raise ValueError("Cardinality of " + str(var) + " not set. Use " + str(var) + ".set_card(m) to set cardinality.")
        cards.append(card)
    return tuple(cards)

def get_shape(self):
    """Cardinalities of the z variables followed by those of the arguments."""
    return self._card_tuple(itertools.chain(self.z, *(self.x)))

def get_shape_in(self):
    """Cardinalities of the conditioning (z) variables."""
    return self._card_tuple(self.z)

def get_shape_out(self):
    """Cardinalities of the argument variables."""
    return self._card_tuple(itertools.chain(*(self.x)))
def prob(self, *args):
    """
    Returns the conditional probability mass function at *args.
    Parameters
    ----------
    *args : int
        The indices to query. E.g. (X+Y|Z).prob(2,3,4) is P(X=3, Y=4 | Z=2).
        The first len(z) indices are for the conditioning variables.
    Returns
    -------
    r : Expr
        The conditional probability as an expression.
    """
    # Flatten all argument Comps into one list of variables.
    cx = []
    for a in self.x:
        cx += list(a)
    args = tuple(args)
    # The name packs several renderings, each introduced by an "@@<style>@@"
    # marker: plain text first, then the PSITIP (repr) form, then LaTeX.
    # Presumably Expr.real selects one per output style — TODO confirm.
    name = "P(" + ",".join((a.tostring(style = PsiOpts.STR_STYLE_STANDARD) + "=" + str(b)) for a, b in zip(cx, args[len(self.z):])) + "|"
    name += ",".join((a.tostring(style = PsiOpts.STR_STYLE_STANDARD) + "=" + str(b)) for a, b in zip(self.z, args[:len(self.z)])) + ")"
    name += "@@" + str(PsiOpts.STR_STYLE_PSITIP) + "@@"
    name += "(" + "&".join(repr(a) for a in cx) + "|" + repr(self.z) + ")"
    name += ".prob(" + ",".join(str(b) for b in args) + ")"
    name += "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@"
    name += PsiOpts.settings["latex_prob"] + "(" + ",".join((a.tostring(style = PsiOpts.STR_STYLE_LATEX) + "=" + str(b)) for a, b in zip(cx, args[len(self.z):])) + "|"
    name += ",".join((a.tostring(style = PsiOpts.STR_STYLE_LATEX) + "=" + str(b)) for a, b in zip(self.z, args[:len(self.z)])) + ")"
    def fcncall(xdist):
        # Index the (conditional) distribution at the full argument tuple.
        return xdist[args]
    R = Expr.real(name)
    return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), Region.universe(), 0, fcncall, [self.copy()]))
    # def fcncall(P):
    #     return P[self][args]
    # R = Expr.real(name)
    # return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), Region.universe(), 0, fcncall, ["model"]))
def pmf(self, simplex = True):
    """Return the probability mass function as an ExprArray.

    The array shape is ``get_shape_in() + get_shape_out()``. With simplex
    true, the last output entry for each input index is ``1 - (sum of the
    other entries)`` instead of its own probability expression.
    """
    shape_in = self.get_shape_in()
    shape_out = self.get_shape_out()
    last = iutil.product(shape_out) - 1
    table = ExprArray.zeros(shape_in + shape_out)
    for xs in itertools.product(*[range(x) for x in shape_in]):
        running = Expr.zero()
        out_points = itertools.product(*[range(y) for y in shape_out])
        for pos, ys in enumerate(out_points):
            key = xs + ys
            if simplex and pos == last:
                table[key] = 1 - running
            else:
                table[key] = self.prob(*key)
            if simplex:
                running += table[key]
    return table

@staticmethod
def fcneval(fcncall, fcnargs):
    """Evaluate a deferred function call.

    The strings "*", "/" and "**" apply that binary operator to fcnargs[0]
    and fcnargs[1]; anything else is called as ``fcncall(*fcnargs)``.
    """
    operators = (
        ("*", lambda lhs, rhs: lhs * rhs),
        ("/", lambda lhs, rhs: lhs / rhs),
        ("**", lambda lhs, rhs: lhs ** rhs),
    )
    for symbol, apply in operators:
        if fcncall == symbol:
            return apply(fcnargs[0], fcnargs[1])
    return fcncall(*fcnargs)

def get_fcneval(self, fcnargs):
    """Evaluate this term's fcncall on fcnargs (see fcneval)."""
    return Term.fcneval(self.fcncall, fcnargs)
def value(self, method = "", num_iter = 30, prog = None):
    """Numerically evaluate this term.

    ONE, EPS and INF give 1.0, 0.0 and inf. Terms without a region, or
    with sn == 0, give None. For region terms:

    * if "bsearch" is among the comma-separated ``method`` entries (always
      for RegionOp regions), ``iutil.gbsearch`` runs ``num_iter`` iterations
      on "reg implies sn * term <= x", and the result is multiplied by sn;
    * otherwise a program is built from ``reg.consonly().imp_flipped()``
      ("dual" / "val" in method enable the corresponding options) and
      ``checkexpr_ge0`` is asked for the optimum of ``-sn * term``. If prog
      is given (a list), the program object is appended to it. Returns
      ``optval * -sn``, or None if no optimum was reported.
    """
    if self.isone():
        return 1.0
    if self.iseps():
        return 0.0
    if self.isinf():
        return float("inf")
    if self.reg is None:
        return None
    if self.sn == 0:
        return None
    ms = method.split(",")
    if isinstance(self.reg, RegionOp):
        ms.append("bsearch")
    if "bsearch" in ms:
        selfterm = Expr.real(str(self))
        return self.sn * iutil.gbsearch(
            lambda x: self.reg.implies(self.sn * selfterm <= x),
            num_iter = num_iter)
    else:
        selfterm = Expr.real(str(self))
        cs = self.reg.consonly().imp_flipped()
        index = IVarIndex()
        cs.record_to(index)
        r = []  # checkexpr_ge0 reports the optimum through this list
        dual_enabled = None
        val_enabled = None
        if "dual" in ms:
            dual_enabled = True
        if "val" in ms:
            val_enabled = True
        cprog = cs.init_prog(index, lp_bounded = False, dual_enabled = dual_enabled, val_enabled = val_enabled)
        cprog.checkexpr_ge0(-self.sn * selfterm, optval = r)
        if prog is not None:
            prog.append(cprog)
        if len(r) == 0:
            return None
        return r[0] * -self.sn
def get_reg_sgn_bds(self):
    """For a REGION term with nonzero sn, return ``(reg, sn, bounds)``.

    reg is a simplified copy of the region with this term eliminated, and
    bounds is the second entry of ``get_lb_ub_eq(self)`` when sn > 0, else
    the first. Returns None for other terms, for sn == 0, or when a RegionOp
    cannot be converted with ``tosimple()``.
    """
    if self.get_type() != TermType.REGION or self.sn == 0:
        return None
    region = self.reg.copy()
    if isinstance(region, RegionOp):
        region = region.tosimple()
    if region is None:
        return None
    sign = self.sn
    bounds = region.get_lb_ub_eq(self)
    region.eliminate_term(self)
    region.simplify_quick()
    return (region, sign, bounds[1] if sign > 0 else bounds[0])

def sort(self):
    """Sort each argument and z in place, then order the arguments by
    their sorting_tuple()."""
    for comp in self.x:
        comp.sort()
    self.z.sort()
    self.x.sort(key = lambda comp: comp.sorting_tuple())
def tostring(self, style = 0, tosort = False, add_bracket = False):
    """Convert to string
    Parameters:
        style : Style of string conversion
            STR_STYLE_STANDARD : I(X,Y;Z|W)
            STR_STYLE_PSITIP : I(X+Y&Z|W)
        tosort : sort names inside the term
        add_bracket : parenthesize when needed (e.g. negative infinity)
    Region terms are shown as the sup/inf (LaTeX) or .maximum/.minimum
    (text) over their region of the bound(s) from get_reg_sgn_bds().
    """
    style = iutil.convert_str_style(style)
    termType = self.get_type()
    if termType == TermType.NIL:
        return "0"
    elif termType == TermType.REAL:
        return self.x[0].tostring(style = style, tosort = tosort)
    elif termType == TermType.IC:
        r = ""
        # One argument -> entropy "H", otherwise mutual information "I".
        if len(self.x) == 1:
            if style & PsiOpts.STR_STYLE_LATEX:
                r += PsiOpts.settings["latex_H"]
            else:
                r += "H"
        else:
            if style & PsiOpts.STR_STYLE_LATEX:
                r += PsiOpts.settings["latex_I"]
            else:
                r += "I"
        r += "("
        namelist = [a.tostring(style = style, tosort = tosort) for a in self.x]
        if tosort:
            namelist.sort()
        if style & PsiOpts.STR_STYLE_PSITIP:
            r += "&".join(namelist)
        elif style & PsiOpts.STR_STYLE_LATEX:
            r += (PsiOpts.settings["latex_mi_delim"] + " ").join(namelist)
        else:
            r += ";".join(namelist)
        if self.z.size() > 0:
            if style & PsiOpts.STR_STYLE_LATEX:
                r += PsiOpts.settings["latex_cond"]
            else:
                r += "|"
            r += self.z.tostring(style = style, tosort = tosort)
        r += ")"
        return r
    elif termType == TermType.REGION:
        # Function-valued terms (name contains fcn_suffix) print their name.
        if self.x[0].varlist[0].name.find(PsiOpts.settings["fcn_suffix"]) >= 0:
            return self.x[0].tostring(style = style, tosort = tosort)
        reg = self.reg
        sn = self.sn
        bds = [self.copy_noreg()]
        # Prefer the bound-based form when available. In PSITIP style this
        # happens only with a single bound or a universe region.
        rsb = self.get_reg_sgn_bds()
        if rsb is not None and (not style & PsiOpts.STR_STYLE_PSITIP
            or len(rsb[2]) == 1 or rsb[0].isuniverse()):
            reg, sn, bds = rsb
        # elif rsb is not None and rsb[0].isuniverse():
        #     reg, sn, bds = rsb
        #     sn *= -1
        reg_universe = reg.isuniverse(canon = True)
        # No bounds: unbounded in the direction of sn.
        if len(bds) == 0:
            if style & PsiOpts.STR_STYLE_LATEX:
                if sn > 0:
                    return PsiOpts.settings["latex_infty"]
                else:
                    if add_bracket:
                        return "(-" + PsiOpts.settings["latex_infty"] + ")"
                    else:
                        return "-" + PsiOpts.settings["latex_infty"]
            else:
                if sn > 0:
                    return "INF"
                else:
                    if add_bracket:
                        return "(-" + "INF" + ")"
                    else:
                        return "-" + "INF"
        if style & PsiOpts.STR_STYLE_LATEX:
            r = ""
            if not reg_universe:
                if sn > 0:
                    r += PsiOpts.settings["latex_sup"]
                else:
                    r += PsiOpts.settings["latex_inf"]
                r += "_{"
                r += reg.tostring(style = style & ~PsiOpts.STR_STYLE_LATEX_ARRAY,
                    tosort = tosort, small = True, skip_outer_exists = True)
                r += "}"
            # Several bounds: min of upper bounds (sn > 0) / max of lower bounds.
            if len(bds) > 1:
                if sn > 0:
                    r += PsiOpts.settings["latex_min"]
                else:
                    r += PsiOpts.settings["latex_max"]
                r += "\\left("
            r += ",\\, ".join(b.tostring(style = style, tosort = tosort, add_bracket = len(bds) == 1 and (add_bracket or not reg_universe)) for b in bds)
            if len(bds) > 1:
                r += "\\right)"
            return r
        else:
            r = ""
            if not reg_universe:
                r += "("
                r += reg.tostring(style = style, tosort = tosort, small = True)
                r += ")"
                if sn > 0:
                    r += ".maximum"
                else:
                    r += ".minimum"
                r += "("
            if len(bds) > 1:
                if style & PsiOpts.STR_STYLE_PSITIP:
                    if sn > 0:
                        r += "emin"
                    else:
                        r += "emax"
                else:
                    if sn > 0:
                        r += "min"
                    else:
                        r += "max"
                r += "("
            r += ", ".join(b.tostring(style = style, tosort = tosort, add_bracket = len(bds) == 1 and (add_bracket or not reg_universe)) for b in bds)
            if len(bds) > 1:
                r += ")"
            if not reg_universe:
                r += ")"
            return r
    return ""
def __str__(self):
    settings = PsiOpts.settings
    return self.tostring(settings["str_style"], tosort = settings["str_tosort"])

def __repr__(self):
    return self.tostring(PsiOpts.settings["str_style_repr"])

@latex_postprocess
def _latex_(self):
    return self.tostring(iutil.convert_str_style("latex"))

def __hash__(self):
    """Hash of the (unordered) argument hashes together with z's hash."""
    arg_hashes = frozenset(hash(comp) for comp in self.x)
    return hash((arg_hashes, hash(self.z)))
def simplify(self, reg = None, bnet = None):
    """Simplify an IC term in place and return self.

    Only IC terms are changed; other terms are returned untouched. Steps:
    1. remove z from every argument. If any argument becomes empty, the
       term is set to zero (x = [], z empty);
    2. drop every argument that is a superset of another argument;
    3. if bnet is given, use ``bnet.check_ic`` to
       a. drop variables of z, and of arguments (when there are 2+), that
          are conditionally independent of the other arguments;
       b. for a 3-argument term, drop an argument x[i] when the other two
          are independent given x[i] and z;
       c. for a 2-argument term, move into z each variable of one argument
          that is independent of the other argument given z.
    The reg parameter is not used by this method.
    """
    if self.get_type() == TermType.IC:
        for i in range(len(self.x)):
            self.x[i] -= self.z  # in-place removal of conditioning variables
        for i in range(len(self.x)):
            if self.x[i].isempty():
                self.x = []
                self.z = Comp.empty()
                return self
        for i in range(len(self.x)):
            for j in range(len(self.x)):
                if i != j and (self.x[j] is not None) and self.x[i].super_of(self.x[j]):
                    self.x[i] = None
                    break
        self.x = [a for a in self.x if a is not None]
        if bnet is not None:
            # i == len(self.x) stands for z; otherwise cc is argument i.
            for i in range(len(self.x) + 1):
                cc = None
                if i == len(self.x):
                    cc = self.z
                else:
                    if len(self.x) == 1:
                        continue
                    cc = self.x[i]
                j = 0
                while j < len(cc.varlist):
                    if bnet.check_ic(Expr.Ic(cc[j],
                        sum((self.x[i2] for i2 in range(len(self.x)) if i2 != i), Comp.empty()),
                        (Comp.empty() if i == len(self.x) else self.z) +
                        sum((cc[j2] for j2 in range(len(cc.varlist)) if j2 != j), Comp.empty()))):
                        cc.varlist.pop(j)
                    else:
                        j += 1
            if len(self.x) == 3:
                for i in range(len(self.x)):
                    if bnet.check_ic(Expr.Ic(self.x[(i + 1) % 3],
                        self.x[(i + 2) % 3],
                        self.x[i] + self.z)):
                        self.x.pop(i)
                        break
            if len(self.x) == 2:
                for i in range(2):
                    j = 0
                    while j < len(self.x[i]):
                        if bnet.check_ic(Expr.Ic(self.x[i][j], self.x[1 - i], self.z)):
                            self.z.varlist.append(self.x[i].varlist[j])
                            self.x[i].varlist.pop(j)
                        else:
                            j += 1
    return self
def simplified(self, reg = None, bnet = None):
    """Return a simplified copy (see simplify)."""
    result = self.copy()
    result.simplify(reg, bnet)
    return result

def simplify_quick(self, **kwargs):
    """Alias of simplify."""
    return self.simplify(**kwargs)

def simplified_quick(self, **kwargs):
    """Alias of simplified."""
    return self.simplified(**kwargs)

def simplify_regterm_perform(self, reg = None):
    """Simplify reg and reg_outer in place (each if set), passing reg along.

    Returns True if at least one of them was simplified.
    """
    changed = False
    for region in (self.reg, self.reg_outer):
        if region is not None:
            region.simplify(reg = reg)
            changed = True
    return changed
def simplify_regterm_expr(self, reg = None):
    """Simplify the regions of this term and return it as an Expr.

    Returns None if there was no region to simplify. When there is no outer
    region, the bound data has at most one bound and its region is the
    universe, the term collapses: to that single bound, or to +/-infinity
    (sign of sn) when there is no bound.
    """
    if not self.simplify_regterm_perform(reg = reg):
        return None
    if self.reg_outer is None:
        rsb = self.get_reg_sgn_bds()
        if rsb is not None:
            bound_reg, sn, bds = rsb
            if len(bds) <= 1 and bound_reg.isuniverse(canon = True):
                if len(bds) == 0:
                    return Expr.inf() * (1 if sn > 0 else -1)
                return bds[0]
    return Expr.fromterm(self)
def match_x(self, other):
    """Greedily pair equal arguments of self and other.

    Returns ``(viss, viso)``: viss[i] is the index in other.x paired with
    self.x[i] (-1 if none), and viso[j] is the reverse mapping. Each argument
    of self takes the first still-unpaired equal argument of other.
    """
    viss = [-1] * len(self.x)
    viso = [-1] * len(other.x)
    for i, mine in enumerate(self.x):
        j = next((j for j, theirs in enumerate(other.x)
                  if viso[j] < 0 and mine == theirs), -1)
        if j >= 0:
            viss[i] = j
            viso[j] = i
    return (viss, viso)

def __eq__(self, other):
    """Equal z and the same arguments up to order (as a multiset)."""
    if self.z != other.z:
        return False
    if len(self.x) != len(other.x):
        return False
    # With equal lengths, every argument of self is paired exactly when
    # every argument of other is.
    _, viso = self.match_x(other)
    return -1 not in viso
def rv_weight(self):
    """Heuristic weight of the random variables in an information term.

    For an IC term with components x_1..x_k given z, sums over every subset
    S of the components the value (-1)^(|S|+1) * (1 - 1/(n_S + 1)), where
    n_S is the number of random variables in z together with the components
    in S. Non-IC terms have weight 0.0.
    """
    r = 0.0
    if self.get_type() == TermType.IC:
        index = IVarIndex()
        self.record_to(index)
        k = len(self.x)
        # The masks of z and of each component do not depend on the subset,
        # so compute them once rather than once per subset (2^k times).
        zmask = index.get_mask(self.z)
        xmasks = [index.get_mask(x) for x in self.x]
        for t in range(1 << k):
            csgn = -1
            mask = zmask
            for i in range(k):
                if (t & (1 << i)) != 0:
                    csgn = -csgn
                    mask |= xmasks[i]
            maskn = iutil.bitcount(mask)
            r += csgn * (1.0 - 1.0 / (maskn + 1))
    return r
def try_remove(self, x, sn):
    """Try to remove the random variables x from this term, in place.

    Parameters:
        x  : variables to remove.
        sn : sign with which the term is used. For sn > 0, x must not appear
             in any x component (so it can only be dropped from z); for
             sn < 0, x must not appear in z. sn == 0 imposes neither rule.

    Returns True if x does not appear in the term at all, or if it was
    removed. Removal is only attempted for IC terms with at most 2
    components, and never when x is in z of a 2-component term. Returns
    False otherwise.
    """
    if not self.ispresent(x):
        return True
    if self.get_type() == TermType.IC:
        if len(self.x) >= 3:
            return False
        if len(self.x) == 2 and self.z.ispresent(x):
            return False
        if sn > 0:
            if any(x2.ispresent(x) for x2 in self.x):
                return False
        if sn < 0:
            if self.z.ispresent(x):
                return False
        # NOTE(review): relies on Comp's -= modifying x2 in place; otherwise
        # the components in self.x would be left unchanged. Confirm.
        for x2 in self.x:
            x2 -= x
        self.z -= x
        return True
    else:
        return False
def z_can_extend_to(self, target, bnet = None):
    """Whether the condition z can be enlarged to target without changing the term.

    True if z already equals target. Otherwise bnet is required: target must
    contain z, and bnet must certify that (target - z) is conditionally
    independent of all x components given z.
    """
    if self.z == target:
        return True
    if bnet is None or not target.super_of(self.z):
        return False
    extra = target - self.z
    all_x = sum(self.x, Comp.empty())
    return bnet.check_ic(Expr.Ic(extra, all_x, self.z))
def try_iadd(self, other, bnet = None, term_allow = None):
    """Try to absorb `other` into this term in place, so that self represents self + other.

    Two IC identities are attempted:
        H(X|Y) + I(X;Y) = H(X)     when other has exactly one more x component
        H(X|Y) + H(Y) = H(X,Y)     when both have the same number of components
    A condition may only be enlarged when z_can_extend_to allows it (using
    bnet, if given). Two epsilon terms are also absorbed, leaving self as is.

    Returns True on success, False otherwise. term_allow is accepted but not
    used here.
    """
    if self.iseps() and other.iseps():
        return True
    if self.get_type() == TermType.IC and other.get_type() == TermType.IC:
        # H(X|Y) + I(X;Y) = H(X)
        if len(self.x) + 1 == len(other.x):
            (viss, viso) = self.match_x(other)
            # Exactly one component of other (the "Y") has no partner in self.
            if viso.count(-1) == 1:
                j = viso.index(-1)
                selfz = self.z + other.x[j] + other.z
                if self.z_can_extend_to(selfz, bnet):
                    otherz = selfz - other.x[j]
                    if other.z_can_extend_to(otherz, bnet):
                        self.z = otherz
                        return True
        # H(X|Y) + H(Y) = H(X,Y)
        if len(self.x) == len(other.x):
            (viss, viso) = self.match_x(other)
            # Exactly one unpaired component on each side: merge them.
            if viss.count(-1) == 1 and viso.count(-1) == 1:
                i = viss.index(-1)
                j = viso.index(-1)
                selfz = self.z + other.x[j] + other.z
                if self.z_can_extend_to(selfz, bnet):
                    otherz = selfz - other.x[j]
                    if other.z_can_extend_to(otherz, bnet):
                        self.x[i] += other.x[j]
                        self.z = otherz
                        return True
    return False
def try_isub(self, other, bnet = None, term_allow = None):
    """Try to absorb `-other` into this term in place, so that self represents self - other.

    The following IC identities are attempted in order:
        H(X) - I(X;Y) = H(X|Y)       (needs TermAllowType.HC)
        H(X) - H(X|Y) = I(X;Y)       (needs TermAllowType.I, plus I3 when
                                      self has 2 or more components)
        H(X,Y) - H(X) = H(Y|X)       (needs TermAllowType.HC)
        H(X,Y) - H(X|Y) = H(Y)
    term_allow = None allows every form. A condition may only be enlarged
    when z_can_extend_to allows it (using bnet, if given).

    Returns True if self was rewritten, False otherwise.
    """
    if self.get_type() == TermType.IC and other.get_type() == TermType.IC:
        # H(X) - I(X;Y) = H(X|Y)
        if len(self.x) + 1 == len(other.x) and (term_allow is None or term_allow & TermAllowType.HC):
            selfz = self.z + other.z
            if self.z_can_extend_to(selfz, bnet) and other.z_can_extend_to(selfz, bnet):
                (viss, viso) = self.match_x(other)
                if viso.count(-1) == 1:
                    j = viso.index(-1)
                    self.z = selfz + other.x[j]
                    return True
        # H(X) - H(X|Y) = I(X;Y)
        if len(self.x) == len(other.x) and (term_allow is None or
                (term_allow & TermAllowType.I and (len(self.x) < 2 or term_allow & TermAllowType.I3))):
            otherz = other.z + self.z
            if other.z_can_extend_to(otherz, bnet):
                (viss, viso) = self.match_x(other)
                # Every component of other must be paired with one of self.
                if viso.count(-1) == 0:
                    self.x.append(otherz - self.z)
                    return True
        # H(X,Y) - H(X) = H(Y|X)
        if len(self.x) == len(other.x) and (term_allow is None or term_allow & TermAllowType.HC):
            selfz = self.z + other.z
            if self.z_can_extend_to(selfz, bnet) and other.z_can_extend_to(selfz, bnet):
                (viss, viso) = self.match_x(other)
                if viss.count(-1) == 1 and viso.count(-1) == 1:
                    i = viss.index(-1)
                    j = viso.index(-1)
                    # Either other's unpaired component lies inside self's, or (for two
                    # components) bnet shows it is conditionally independent of
                    # self.x[1 - i] given self.x[i] + selfz.
                    if (self.x[i].super_of(other.x[j])
                            or (len(self.x) == 2 and bnet is not None and
                            bnet.check_ic(Expr.Ic(other.x[j], self.x[1 - i], self.x[i] + selfz)))):
                        self.x[i] -= other.x[j]
                        self.z = selfz + other.x[j]
                        return True
        # H(X,Y) - H(X|Y) = H(Y)
        if len(self.x) == len(other.x):
            (viss, viso) = self.match_x(other)
            if viss.count(-1) == 1 and viso.count(-1) == 1:
                i = viss.index(-1)
                j = viso.index(-1)
                if self.x[i].super_of(other.x[j]):
                    otherz = (self.x[i] - other.x[j]) + self.z
                    if other.z_can_extend_to(otherz, bnet):
                        self.x[i] -= other.x[j]
                        return True
    return False
def try_isub_flipsign(self, other, bnet = None, term_allow = None):
    """Rewrite the difference self - other as a sum of two terms, in place.

    Uses I(X+U & Y+Z) - I(X & Y) = I(U & Y | X) + I(X+U & Z | Y). If the two
    components of self contain those of other (in either order), self becomes
    I(U & Y | X) and other becomes I(X+U & Z | Y), both additionally
    conditioned on self.z + other.z. Requirements:
    - both terms are 2-component IC terms;
    - both conditions can be extended (z_can_extend_to, using bnet if given);
    - term_allow, if not None, allows TermAllowType.IC.

    Returns True if the rewrite was done, after which self + other equals the
    original self - other. Returns False otherwise.
    """
    if self.get_type() == TermType.IC and other.get_type() == TermType.IC and (term_allow is None or term_allow & TermAllowType.IC):
        # I(X+U & Y+Z) - I(X & Y) = I(U & Y | X) + I(X+U & Z | Y)
        if len(self.x) == 2 and len(other.x) == 2:
            selfz = self.z + other.z
            if self.z_can_extend_to(selfz, bnet) and other.z_can_extend_to(selfz, bnet):
                s0 = self.x[0]
                s1 = self.x[1]
                # Try both orders of other's components.
                for it in range(2):
                    o0 = other.x[it]
                    o1 = other.x[1 - it]
                    if s0.super_of(o0) and s1.super_of(o1):
                        self.x[0] = s0 - o0
                        self.x[1] = o1
                        self.z = selfz + o0
                        other.x[0] = s0
                        other.x[1] = s1 - o1
                        other.z = selfz + o1
                        return True
    return False
def __and__(self, other):
    """Combine two terms with '&'.

    The x components of both terms are concatenated (copied) and their
    conditions joined. A Comp operand is first turned into its entropy term;
    array operands return NotImplemented.
    """
    if isinstance(other, (CompArray, ExprArray)):
        return NotImplemented
    if isinstance(other, Comp):
        other = Term.H(other)
    components = [a.copy() for a in self.x]
    components.extend([a.copy() for a in other.x])
    return Term(components, self.z + other.z,
                z_present = self.z_present or other.z_present)
def __or__(self, other):
    """Condition this term on other ('|'), keeping copies of the x components.

    A Comp operand is taken via its entropy term, and an int operand is
    treated as the empty condition. The result always marks its condition as
    present. Array operands return NotImplemented.
    """
    if isinstance(other, (CompArray, ExprArray)):
        return NotImplemented
    if isinstance(other, Comp):
        other = Term.H(other)
    if isinstance(other, int):
        other = Term.H(Comp.empty())
    components = [a.copy() for a in self.x]
    condition = self.z + other.allcomp()
    return Term(components, condition, z_present = True)
def symbol_split(self):
    """Flatten the term into a list of symbols.

    The x components (copied) are separated by "&". A "|" follows when z is
    nonempty or z_present is set, and then a copy of z when z is nonempty.
    """
    symbols = []
    for pos, comp in enumerate(self.x):
        if pos:
            symbols.append("&")
        symbols.append(comp.copy())
    has_cond = not self.z.isempty()
    if has_cond or self.z_present:
        symbols.append("|")
    if has_cond:
        symbols.append(self.z.copy())
    return symbols
@staticmethod
def from_symbols(symbols, prefer_multi = False):
    """Build a Term from a flat sequence of symbols.

    Accepted symbols:
    - Terms, expanded with symbol_split;
    - Comps;
    - the separators "&" (start a new x component) and "|" (what follows
      goes into the condition z);
    - anything accepted by iutil.ensure_comp.
    A ConcDist symbol is returned as-is.

    With prefer_multi, if there are at least two symbols but the parse gives
    a single x component, the symbols before the first bar are separated by
    "&" (each becomes its own component) and the result is parsed again.
    """
    cs = []
    for t in symbols:
        if isinstance(t, Term):
            t2 = t.symbol_split()
            cs += t2
        elif isinstance(t, Comp):
            cs.append(t.copy())
        elif isinstance(t, str) and (t == "&" or t == "|"):
            cs.append(t)
        elif isinstance(t, ConcDist):
            return t
        else:
            cs.append(iutil.ensure_comp(t).copy())
    r = Term.H(Comp.empty())
    z_present = False
    for t in cs:
        if isinstance(t, str):
            if t == "&":
                r.x.append(Comp.empty())
            elif t == "|":
                z_present = True
        elif isinstance(t, Comp):
            if z_present:
                r.z += t
            else:
                r.x[-1] += t
    if prefer_multi and len(symbols) >= 2 and len(r.x) <= 1:
        rsymbols = []
        have_bar = False
        for t in symbols:
            is_bar = isinstance(t, str) and t == "|"
            # Never put "&" directly before a "|" separator: that would open
            # an empty x component in front of the condition.
            if rsymbols and not have_bar and not is_bar:
                rsymbols.append("&")
            rsymbols.append(t)
            if is_bar:
                have_bar = True
            elif isinstance(t, Term) and len(t.z):
                have_bar = True
        return Term.from_symbols(rsymbols, prefer_multi=False)
    return r
def istight(self, canon = False):
    """Whether the term is tight.

    Only region terms with an outer region can fail to be tight. For those,
    canon=True reports False, and canon=False checks whether reg_outer
    implies reg.
    """
    if self.get_type() != TermType.REGION:
        return True
    if self.reg_outer is None:
        return True
    return False if canon else self.reg_outer.implies(self.reg)
def tighten(self):
    """Drop reg_outer when istight(canon = False) shows it is redundant."""
    if not self.istight(canon = False):
        return
    self.reg_outer = None
def lu_bound(self, sn, name = None):
    """Lower (sn > 0) or upper (sn <= 0) bound of a region term.

    Non-region terms are simply copied. For a region term, the defining
    variable x[0] is replaced in a copy by a new real variable named
    name (default: x[0]'s name) + "_LB" / "_UB". When sn and self.sn have
    opposite signs, the outer region (if any) becomes the region.
    """
    if self.get_type() != TermType.REGION:
        return self.copy()
    bound = self.copy()
    if name is None:
        name = self.x[0].get_name()
    name += "_LB" if sn > 0 else "_UB"
    bound.substitute(self.x[0], Comp.real(name))
    if sn * self.sn < 0 and bound.reg_outer is not None:
        bound.reg = bound.reg_outer
        bound.reg_outer = None
    return bound
def lower_bound(self, name = None):
    """Lower bound of the term: lu_bound with sign +1."""
    sign = 1
    return self.lu_bound(sign, name = name)
def upper_bound(self, name = None):
    """Upper bound of the term: lu_bound with sign -1."""
    sign = -1
    return self.lu_bound(sign, name = name)
def ispresent(self, x):
    """Return whether any variable in x appears in this term.

    x may be an IVar, a Comp, a variable name (str), or any other object
    with allcomp(). The search covers the regions of a region term,
    object-valued function arguments, the x components and the condition z.
    """
    if isinstance(x, IVar):
        x = Comp([x])
    if not isinstance(x, str) and not isinstance(x, Comp):
        x = x.allcomp()
    if self.get_type() == TermType.REGION:
        if self.reg is not None and self.reg.ispresent(x):
            return True
        if self.reg_outer is not None and self.reg_outer.ispresent(x):
            return True
    if self.fcnargs is not None:
        for a in self.fcnargs:
            if isinstance(a, (IVar, Comp, Term, Expr, Region)):
                if a.ispresent(x):
                    return True
    # Check variable by variable (a name string is a single "variable").
    xlist = []
    if isinstance(x, str):
        xlist = [x]
    else:
        xlist = x.varlist
    for y in xlist:
        for a in self.x:
            if a.ispresent(y):
                return True
        if self.z.ispresent(y):
            return True
    return False
def __contains__(self, other):
    """Whether other is contained in this term (`other in term`).

    First checks the regions of a region term and object-valued function
    arguments. Then, if the "condition_included" setting is on, checks each
    x component joined with z; otherwise checks each x component and z
    separately.
    """
    condition_included = PsiOpts.settings.get("condition_included", False)
    x = other
    if isinstance(x, IVar):
        x = Comp([x])
    if self.get_type() == TermType.REGION:
        if self.reg is not None and x in self.reg:
            return True
        if self.reg_outer is not None and x in self.reg_outer:
            return True
    if self.fcnargs is not None:
        for a in self.fcnargs:
            if isinstance(a, (IVar, Comp, Term, Expr, Region)):
                if x in a:
                    return True
    if condition_included:
        for a in self.x:
            if x in a + self.z:
                return True
    else:
        for a in self.x:
            if x in a:
                return True
        if x in self.z:
            return True
    return False
def rename_var(self, name0, name1):
    """Rename variable name0 to name1 throughout the term, in place.

    Covers the regions of a region term, object-valued function arguments,
    the x components and the condition z.
    """
    if self.get_type() == TermType.REGION:
        for attr in ("reg", "reg_outer"):
            region = getattr(self, attr)
            if region is not None:
                region.rename_var(name0, name1)
    if self.fcnargs is not None:
        for arg in self.fcnargs:
            if isinstance(arg, (IVar, Comp, Term, Expr, Region)):
                arg.rename_var(name0, name1)
    for comp in self.x:
        comp.rename_var(name0, name1)
    self.z.rename_var(name0, name1)
def rename_map(self, namemap):
    """Rename variables according to namemap, in place; returns self.

    Covers the regions of a region term, object-valued function arguments,
    the x components and the condition z.
    """
    if self.get_type() == TermType.REGION:
        for attr in ("reg", "reg_outer"):
            region = getattr(self, attr)
            if region is not None:
                region.rename_map(namemap)
    if self.fcnargs is not None:
        for arg in self.fcnargs:
            if isinstance(arg, (IVar, Comp, Term, Expr, Region)):
                arg.rename_map(namemap)
    for comp in self.x:
        comp.rename_map(namemap)
    self.z.rename_map(namemap)
    return self
@fcn_substitute
def substitute(self, v0, v1):
    """Substitute variable v0 by v1 in place (v1 can be compound).

    Covers the regions of a region term, object-valued function arguments,
    the x components and the condition z.
    """
    if self.get_type() == TermType.REGION:
        for attr in ("reg", "reg_outer"):
            region = getattr(self, attr)
            if region is not None:
                region.substitute(v0, v1)
    if self.fcnargs is not None:
        for arg in self.fcnargs:
            if isinstance(arg, (IVar, Comp, Term, Expr, Region)):
                arg.substitute(v0, v1)
    for comp in self.x:
        comp.substitute(v0, v1)
    self.z.substitute(v0, v1)
@fcn_substitute
def substitute_whole(self, v0, v1):
    """Substitute v0 by v1 in place using substitute_whole on every part (v1 can be compound).

    When the "condition_included" setting is on, z is temporarily joined into
    each x component during the substitution and removed again afterwards.
    """
    condition_included = PsiOpts.settings.get("condition_included", False)
    if self.get_type() == TermType.REGION:
        if self.reg is not None:
            self.reg.substitute_whole(v0, v1)
        if self.reg_outer is not None:
            self.reg_outer.substitute_whole(v0, v1)
    if self.fcnargs is not None:
        for a in self.fcnargs:
            if isinstance(a, (IVar, Comp, Term, Expr, Region)):
                a.substitute_whole(v0, v1)
    # NOTE(review): the two loops below rely on Comp's += / -= acting in
    # place on the elements of self.x. Confirm.
    if condition_included:
        for a in self.x:
            a += self.z
    for a in self.x:
        a.substitute_whole(v0, v1)
    self.z.substitute_whole(v0, v1)
    if condition_included:
        for a in self.x:
            a -= self.z
def get_var_avoid(self, v):
    """Intersect the parts of the term that contain v, each with v removed.

    The parts are each x component joined with z, plus z alone. Returns None
    if no part contains v.
    """
    r = None
    for a in self.x + [Comp.empty()]:
        b = a + self.z
        if b.ispresent(v):
            b = b - v
            if r is None:
                r = b
            else:
                r = r.inter(b)
    return r
class Expr(IBaseObj):
"""An expression
"""
def __init__(self, terms, mhash = None, meta = None):
    """Create an expression from a list of (Term, coefficient) pairs.

    Parameters:
        terms : list of (Term, coefficient) pairs, stored as given (not copied).
        mhash : cached hash value, or None to compute it lazily in __hash__.
        meta  : optional metadata (tostring reads "float_style" via get_meta).
    """
    self.meta = meta
    self.mhash = mhash
    self.terms = terms
def copy(self):
    """Copy with copied terms and metadata; the cached hash is kept."""
    new_terms = [(term.copy(), coeff) for term, coeff in self.terms]
    return Expr(new_terms, self.mhash, iutil.copy(self.meta))
def copy_(self, other):
    """Overwrite this expression in place with a copy of other (terms, hash cache, metadata)."""
    self.terms = [(term.copy(), coeff) for term, coeff in other.terms]
    self.mhash = other.mhash
    self.meta = iutil.copy(other.meta)
def copy_noreg(self):
    """Copy using Term.copy_noreg for each term; the hash cache and metadata are not kept."""
    stripped = [(term.copy_noreg(), coeff) for term, coeff in self.terms]
    return Expr(stripped, None)
@staticmethod
def parse(s):
    """Parse a string, e.g. I(X;Y,Z|W) + 2H(X Z)
    """
    parsed = RegionParser.parse_default(s)
    return iutil.ensure_expr(parsed)
@staticmethod
def fromcomp(x):
    """Expression consisting of Term.fromcomp(x) with coefficient 1."""
    term = Term.fromcomp(x)
    return Expr([(term, 1.0)])
@staticmethod
def fromterm(x):
    """Expression consisting of a copy of Term x with coefficient 1."""
    term = x.copy()
    return Expr([(term, 1.0)])
@staticmethod
def fcn(fcncall, name = None):
    """Wrap any function mapping a ConcModel to a number as an Expr.
    E.g. the Hamming distortion is given by
    Expr.fcn(lambda P: P[X+Y].mean(lambda x, y: float(x != y))).
    For optimization using PyTorch, the return value should be a scalar
    torch.Tensor with gradient information.
    If name is None, a fresh name "fcn_<count>" is generated.
    """
    if name is None:
        name = "fcn_" + str(iutil.get_count("fcn"))
    term = Term([Comp.real(name)], Comp.empty(), Region.universe(), 0, fcncall, ["model"])
    return Expr.fromterm(term)
def find(self, *args):
    """Delegate to find(*args) on the join of all components (allcomp())."""
    components = self.allcomp()
    return components.find(*args)
def get_const(self, only = True):
    """Sum of the coefficients of constant (isone) terms.

    If only is True, returns None as soon as a non-constant term is found;
    otherwise non-constant terms are ignored.
    """
    total = 0.0
    for term, coeff in self.terms:
        if not term.isone():
            if only:
                return None
            continue
        total += coeff
    return total
def split_ic_real(self):
    """Split into [IC part, everything else] as two new Exprs.

    The Term objects are shared with self, not copied.
    """
    ic_part = Expr.zero()
    other_part = Expr.zero()
    for term, coeff in self.terms:
        target = ic_part if term.get_type() == TermType.IC else other_part
        target.terms.append((term, coeff))
    return [ic_part, other_part]
def get_name(self):
    """Name of the join of all components (allcomp().get_name())."""
    components = self.allcomp()
    return components.get_name()
def __len__(self):
    """Number of (term, coefficient) pairs."""
    count = len(self.terms)
    return count
def len_iccount(self):
    """Count components: len(x) for each IC term, plus 1 for every other term."""
    return sum(len(term.x) if term.get_type() == TermType.IC else 1
               for term, _ in self.terms)
def get_maxent_comp(self):
    """For a single-term expression, that term's get_maxent_comp(); otherwise None."""
    if len(self.terms) == 1:
        only_term = self.terms[0][0]
        return only_term.get_maxent_comp()
    return None
def __bool__(self):
    """An expression is truthy iff it has at least one term."""
    return len(self.terms) > 0
def __getitem__(self, key):
    """Index or slice the terms; always returns an Expr (terms are shared, not copied)."""
    picked = self.terms[key]
    return Expr(picked if isinstance(picked, list) else [picked])
def allcomprealvar(self):
    """Join of the real variables (terms with isrealvar() and no fcncall)."""
    result = Comp.empty()
    for term, _ in self.terms:
        if term.isrealvar() and term.fcncall is None:
            result += term.x[0]
    return result
def allcompreal_exprlist(self):
    """ExprArray of the distinct REAL-type terms, each as its own Expr."""
    collected = ExprArray([])
    for term, _ in self.terms:
        if term.get_type() != TermType.REAL:
            continue
        collected.iadd_noduplicate([Expr.fromterm(term)])
    return collected
def allcomprealvar_exprlist(self):
    """ExprArray of the distinct real-variable terms (no fcncall), each as its own Expr."""
    collected = ExprArray([])
    for term, _ in self.terms:
        if not term.isrealvar() or term.fcncall is not None:
            continue
        collected.iadd_noduplicate([Expr.fromterm(term)])
    return collected
def __iadd__(self, other):
    """In-place addition; other (Comp, Expr or anything iutil.ensure_expr accepts) has its terms copied in."""
    if isinstance(other, Comp):
        other = Expr.fromcomp(other)
    other = iutil.ensure_expr(other)
    # Build the list first so that e += e does not iterate while growing.
    self.terms.extend([(term.copy(), coeff) for term, coeff in other.terms])
    self.mhash = None
    return self
def __add__(self, other):
    """Sum of two expressions (terms copied, metadata concatenated)."""
    if isinstance(other, (CompArray, ExprArray)):
        return NotImplemented
    if isinstance(other, Comp):
        other = Expr.fromcomp(other)
    other = iutil.ensure_expr(other)
    combined = [(term.copy(), coeff) for term, coeff in self.terms]
    combined += [(term.copy(), coeff) for term, coeff in other.terms]
    return Expr(combined, meta = iutil.meta_concat([self.meta, other.meta]))
def __radd__(self, other):
    """Reflected addition other + self (other's terms first)."""
    if isinstance(other, Comp):
        other = Expr.fromcomp(other)
    other = iutil.ensure_expr(other)
    combined = [(term.copy(), coeff) for term, coeff in other.terms]
    combined += [(term.copy(), coeff) for term, coeff in self.terms]
    return Expr(combined, meta = iutil.meta_concat([other.meta, self.meta]))
def __neg__(self):
    """Negated copy (every coefficient sign flipped, metadata copied)."""
    flipped = [(term.copy(), -coeff) for term, coeff in self.terms]
    return Expr(flipped, None, iutil.copy(self.meta))
def __isub__(self, other):
    """In-place subtraction; copies of other's terms are appended with negated coefficients."""
    if isinstance(other, Comp):
        other = Expr.fromcomp(other)
    other = iutil.ensure_expr(other)
    # Build the list first so that e -= e does not iterate while growing.
    self.terms.extend([(term.copy(), -coeff) for term, coeff in other.terms])
    self.mhash = None
    return self
def __sub__(self, other):
    """Difference self - other (terms copied, metadata concatenated)."""
    if isinstance(other, (CompArray, ExprArray)):
        return NotImplemented
    if isinstance(other, Comp):
        other = Expr.fromcomp(other)
    other = iutil.ensure_expr(other)
    combined = [(term.copy(), coeff) for term, coeff in self.terms]
    combined += [(term.copy(), -coeff) for term, coeff in other.terms]
    return Expr(combined, meta = iutil.meta_concat([self.meta, other.meta]))
def __rsub__(self, other):
    """Reflected subtraction other - self."""
    if isinstance(other, Comp):
        other = Expr.fromcomp(other)
    other = iutil.ensure_expr(other)
    combined = [(term.copy(), coeff) for term, coeff in other.terms]
    combined += [(term.copy(), -coeff) for term, coeff in self.terms]
    return Expr(combined, meta = iutil.meta_concat([other.meta, self.meta]))
def __imul__(self, other):
    """In-place scaling by a number or a constant Expr.

    A non-constant Expr factor falls back to the ordinary product self * other,
    which returns a new object.
    """
    if isinstance(other, Expr):
        factor = other.get_const()
        if factor is None:
            return self * other
        other = factor
    self.terms = [(term, coeff * other) for term, coeff in self.terms]
    self.mhash = None
    return self
def __mul__(self, other):
    """Multiply by a number or by an Expr.

    - Array operands return NotImplemented.
    - A numeric 0 gives the zero expression.
    - A constant Expr factor (get_const() not None) is used as a number.
    - If other is non-constant but self is constant, the result is
      other * constant.
    - If both are non-constant, the product is a real function term named
      via iutil.fcn_name_maker with fcncall "*".
    Otherwise the result is a copy with scaled coefficients.
    """
    if isinstance(other, CompArray) or isinstance(other, ExprArray):
        return NotImplemented
    if isinstance(other, int) or isinstance(other, float):
        if other == 0:
            return Expr.zero()
    if isinstance(other, Expr):
        tother = other.get_const()
        if tother is None:
            tself = self.get_const()
            if tself is None:
                # Both factors are non-constant: keep the product symbolic.
                return Expr.fromterm(Term(Comp.real(
                    iutil.fcn_name_maker("*", [self, other], lname = PsiOpts.settings["latex_times"] + " ", infix = True)
                ), reg = Region.universe(), fcncall = "*", fcnargs = [self, other]))
            else:
                return other * tself
        else:
            other = tother
    return Expr([(a.copy(), c * other) for (a, c) in self.terms], None, iutil.copy(self.meta))
def __rmul__(self, other):
    """Reflected multiplication other * self.

    A numeric 0 gives the zero expression, and a constant Expr factor is used
    as a number. A non-constant Expr factor makes the result a real function
    term with fcncall "*" and arguments [other, self]. Otherwise the result
    is a copy with scaled coefficients.
    """
    if isinstance(other, int) or isinstance(other, float):
        if other == 0:
            return Expr.zero()
    if isinstance(other, Expr):
        tother = other.get_const()
        if tother is None:
            # Non-constant factor: keep the product symbolic.
            return Expr.fromterm(Term(Comp.real(
                iutil.fcn_name_maker("*", [other, self], lname = PsiOpts.settings["latex_times"] + " ", infix = True)
            ), reg = Region.universe(), fcncall = "*", fcnargs = [other, self]))
        else:
            other = tother
    return Expr([(a.copy(), c * other) for (a, c) in self.terms], None, iutil.copy(self.meta))
def __itruediv__(self, other):
    """In-place division by a number or a constant Expr.

    A non-constant Expr divisor falls back to self / other, which returns a
    new object.
    """
    if isinstance(other, (CompArray, ExprArray)):
        return NotImplemented
    if isinstance(other, Expr):
        divisor = other.get_const()
        if divisor is None:
            return self / other
        other = divisor
    self.terms = [(term, coeff / other) for term, coeff in self.terms]
    self.mhash = None
    return self
def __pow__(self, other):
    """Symbolic power self ** other, as a real function term with fcncall "**"."""
    if isinstance(other, (CompArray, ExprArray)):
        return NotImplemented
    label = iutil.fcn_name_maker("**", [self, other], lname = "^", infix = True, latex_group = True)
    term = Term(Comp.real(label), reg = Region.universe(), fcncall = "**", fcnargs = [self, other])
    return Expr.fromterm(term)
def __ipow__(self, other):
    """In-place power: same as self ** other (returns a new object)."""
    result = self ** other
    return result
def __rpow__(self, other):
    """Reflected power other ** self, with other turned into a constant Expr."""
    base = Expr.const(other)
    return base ** self
def __invert__(self):
    """~expr: the constraint expr == 0."""
    constraint = (self == 0)
    return constraint
def record_to(self, index):
    """Record the variables of every term into index."""
    for term, _coeff in self.terms:
        term.record_to(index)
def regtermmap(self, cmap, recur):
    """Delegate to regtermmap on the region self >= 0."""
    region = self >= 0
    return region.regtermmap(cmap, recur)
def allcomp(self):
    """Join of allcomp() over all terms."""
    combined = Comp.empty()
    for term, _ in self.terms:
        combined += term.allcomp()
    return combined
def allcomprv_shallow(self):
    """Join of allcomprv_shallow() over all terms."""
    combined = Comp.empty()
    for term, _ in self.terms:
        combined += term.allcomprv_shallow()
    return combined
def size(self):
    """Number of (term, coefficient) pairs."""
    count = len(self.terms)
    return count
def iszero(self):
    """Whether the expression is zero, i.e. has no terms."""
    return not self.terms
def setzero(self):
    """Set expression to zero (drop all terms, clear the hash cache)."""
    self.mhash = None
    self.terms = []
def isnonneg(self):
    """Whether the expression is always nonnegative.

    True if every term with a non-negligible coefficient has a positive
    coefficient and is itself nonnegative (Term.isnonneg).
    """
    tol = PsiOpts.settings["eps"]
    for term, coeff in self.terms:
        if abs(coeff) <= tol:
            continue
        if coeff < 0 or not term.isnonneg():
            return False
    return True
def isnonpos(self):
    """Whether the expression is always nonpositive.

    True if every term with a non-negligible coefficient has a negative
    coefficient and is itself nonnegative (Term.isnonneg).
    """
    tol = PsiOpts.settings["eps"]
    for term, coeff in self.terms:
        if abs(coeff) <= tol:
            continue
        if coeff > 0 or not term.isnonneg():
            return False
    return True
def isnonneg_ic2(self):
    """Whether all non-negligible terms have positive coefficients and satisfy Term.isic2."""
    tol = PsiOpts.settings["eps"]
    for term, coeff in self.terms:
        if abs(coeff) <= tol:
            continue
        if coeff < 0 or not term.isic2():
            return False
    return True
def isnonpos_ic2(self):
    """Whether all non-negligible terms have negative coefficients and satisfy Term.isic2."""
    tol = PsiOpts.settings["eps"]
    for term, coeff in self.terms:
        if abs(coeff) <= tol:
            continue
        if coeff > 0 or not term.isic2():
            return False
    return True
def isnonpos_hc(self):
    """Whether all non-negligible terms have negative coefficients and satisfy Term.ishc."""
    tol = PsiOpts.settings["eps"]
    for term, coeff in self.terms:
        if abs(coeff) <= tol:
            continue
        if coeff > 0 or not term.ishc():
            return False
    return True
def isrealvar(self):
    """Whether the expression is a single real-variable term with coefficient 1 (within eps)."""
    if len(self.terms) != 1:
        return False
    (term, coeff), = self.terms
    if abs(coeff - 1.0) > PsiOpts.settings["eps"]:
        return False
    return term.isrealvar()
@staticmethod
def zero():
    """The constant zero expression (no terms)."""
    no_terms = []
    return Expr(no_terms)
@staticmethod
def H(x):
    """Entropy H(x) with coefficient 1."""
    term = Term.H(x)
    return Expr([(term, 1.0)])
@staticmethod
def I(x, y):
    """Mutual information I(x;y) with coefficient 1."""
    term = Term.I(x, y)
    return Expr([(term, 1.0)])
@staticmethod
def Hc(x, z):
    """Conditional entropy H(x|z) with coefficient 1."""
    term = Term.Hc(x, z)
    return Expr([(term, 1.0)])
@staticmethod
def Ic(x, y, z):
    """Conditional mutual information I(x;y|z) with coefficient 1."""
    term = Term.Ic(x, y, z)
    return Expr([(term, 1.0)])
@staticmethod
def real(name):
    """Real variable: name may be an IVar, a Comp, or a name passed to Comp.real."""
    if isinstance(name, IVar):
        comp = Comp([name])
    elif isinstance(name, Comp):
        comp = name
    else:
        comp = Comp.real(name)
    return Expr([(Term([comp], Comp.empty()), 1.0)])
@staticmethod
def eps():
    """Epsilon: the single term IVar.eps() with coefficient 1."""
    term = Term([Comp([IVar.eps()])], Comp.empty())
    return Expr([(term, 1.0)])
@staticmethod
def one():
    """One: the single term IVar.one() with coefficient 1."""
    term = Term([Comp([IVar.one()])], Comp.empty())
    return Expr([(term, 1.0)])
@staticmethod
def inf():
    """Infinity: the single term IVar.inf() with coefficient 1."""
    term = Term([Comp([IVar.inf()])], Comp.empty())
    return Expr([(term, 1.0)])
@staticmethod
def const(c):
    """Constant c (as c times the one term); |c| <= eps gives the zero expression."""
    threshold = PsiOpts.settings["eps"]
    if abs(c) <= threshold:
        return Expr.zero()
    unit = Term([Comp([IVar.one()])], Comp.empty())
    return Expr([(unit, float(c))])
def distance(self, other):
    """Heuristic distance between two expressions.

    Each pair of terms is scored as 0.8 * Term.distance plus 0.2 * the
    relative coefficient difference |c1 - c2| / (2 * max(|c1|, |c2|)). The
    pairwise matrix is reduced by iutil.distance_frommat.
    """
    n0 = len(self.terms)
    n1 = len(other.terms)
    def term_distance(a, b):
        # Two zero coefficients are identical; avoid dividing 0 by 0.
        cmax = max(abs(a[1]), abs(b[1]))
        cdiff = abs(a[1] - b[1]) / (cmax * 2) if cmax > 0 else 0.0
        return a[0].distance(b[0]) * 0.8 + cdiff * 0.2
    m = [[term_distance(x, y) for y in other.terms] for x in self.terms]
    return iutil.distance_frommat(m, n0, n1)
def make_similar(self, other):
    """Make this expression resemble other, in place; returns self.

    Only acts when both expressions have more than one term. Each term of
    self is paired with the closest term of other, scored as 0.8 *
    Term.distance + 0.2 * relative coefficient difference. The term is then
    adjusted with Term.make_similar, and self's terms are regrouped in the
    order of their partners in other.
    """
    if len(self.terms) > 1 and len(other.terms) > 1:
        def term_distance(a, b):
            # Two zero coefficients are identical; avoid dividing 0 by 0.
            cmax = max(abs(a[1]), abs(b[1]))
            cdiff = abs(a[1] - b[1]) / (cmax * 2) if cmax > 0 else 0.0
            return a[0].distance(b[0]) * 0.8 + cdiff * 0.2
        ts = [[] for _ in other.terms]
        for a in self.terms:
            mind = None
            mini = -1
            for i, b in enumerate(other.terms):
                t = term_distance(a, b)
                if mind is None or mind > t:
                    mind = t
                    mini = i
            a[0].make_similar(other.terms[mini][0])
            ts[mini].append(a)
        self.terms = [t for t0 in ts for t in t0]
        self.mhash = None
    return self
def commonpart_coeff(self, v, forbid_h = True):
    """Coefficient attributed to the common part of v in this expression.

    Only IC terms with non-negligible coefficients c contribute.
    - v a single variable (not a list/tuple): adds c for every IC term in
      which v appears in each x component and not in z.
    - v a list/tuple: the variables of v present in z are dropped (v2).
      A term is skipped if no x component contains a variable of v, if v2
      is empty, or if some x component contains no variable of v2.
      - Terms with more than two components are expanded recursively as
        Term(x[1:], z) - Term(x[1:], z + x[0]).
      - Otherwise, for each variable of v2 present in every component, c * inf
        is added when forbid_h is True; if some variable of v2 is in no
        component the term is skipped; if the loop finishes without skipping,
        c is added once.

    Returns the total (possibly infinite), or None if a recursive evaluation
    returns None.
    """
    r = 0.0
    for (a, c) in self.terms:
        if abs(c) <= PsiOpts.settings["eps"]:
            continue
        if a.get_type() == TermType.IC:
            if not isinstance(v, list) and not isinstance(v, tuple):
                if not a.z.ispresent(v) and all(t.ispresent(v) for t in a.x):
                    r += c
            else:
                if all(not t.ispresent(vt) for t in a.x for vt in v):
                    continue
                v2 = [vt for vt in v if not a.z.ispresent(vt)]
                if len(v2) == 0:
                    continue
                if any(all(not t.ispresent(vt) for vt in v2) for t in a.x):
                    continue
                if len(a.x) > 2:
                    # NOTE(review): forbid_h is not forwarded to this recursive
                    # call, so the default (True) is used. Confirm that is intended.
                    tr = (Expr.fromterm(Term(a.x[1:], a.z))
                        - Expr.fromterm(Term(a.x[1:], a.z+a.x[0]))).commonpart_coeff(v)
                    if tr is None:
                        return None
                    r += c * tr
                else:
                    for vt in v2:
                        tpres = [t.ispresent(vt) for t in a.x]
                        if all(tpres):
                            if forbid_h:
                                r += c * numpy.inf
                        elif not any(tpres):
                            break
                    else:
                        # for-else: reached only when the loop did not break.
                        r += c
    return r
def ent_coeff(self, v):
    """Total coefficient of the IC terms in which v appears in every x component but not in z."""
    tol = PsiOpts.settings["eps"]
    total = 0.0
    for term, coeff in self.terms:
        if abs(coeff) <= tol or term.get_type() != TermType.IC:
            continue
        if all(comp.ispresent(v) for comp in term.x) and not term.z.ispresent(v):
            total += coeff
    return total
def var_mi_only(self, v):
    """Whether the entropy coefficient of v (ent_coeff) is within eps of zero."""
    coeff = self.ent_coeff(v)
    return abs(coeff) <= PsiOpts.settings["eps"]
def z3(self, vardict):
    """Convert to a z3 expression.

    Returns None if the z3 module is unavailable or there are no terms. Each
    term is converted with Term.z3(vardict), scaled by |c| (via
    iutil.float_toz3) unless |c| == 1, and added or subtracted according to
    the sign of c.
    """
    if z3 is None:
        return None
    acc = None
    for term, coeff in self.terms:
        piece = term.z3(vardict)
        if abs(coeff) != 1:
            piece = piece * iutil.float_toz3(abs(coeff))
        if coeff >= 0:
            acc = piece if acc is None else acc + piece
        else:
            acc = -piece if acc is None else acc - piece
    return acc
def istight(self, canon = False):
    """Whether every term is tight (Term.istight with the same canon flag)."""
    for term, _ in self.terms:
        if not term.istight(canon):
            return False
    return True
def tighten(self):
    """Call Term.tighten on every term."""
    for term, _ in self.terms:
        term.tighten()
def lu_bound(self, sn, name = None):
    """Termwise lower (sn = 1) or upper (sn = -1) bound.

    Each term is bounded in direction 1 when sn * coefficient >= 0, and -1
    otherwise.
    """
    bounded = []
    for term, coeff in self.terms:
        direction = 1 if sn * coeff >= 0 else -1
        bounded.append((term.lu_bound(direction, name = name), coeff))
    return Expr(bounded)
def lower_bound(self, name = None):
    """Lower bound: lu_bound with sign +1."""
    sign = 1
    return self.lu_bound(sign, name = name)
def upper_bound(self, name = None):
    """Upper bound: lu_bound with sign -1."""
    sign = -1
    return self.lu_bound(sign, name = name)
def get_reg_sgn_bds(self):
    """Separate out the term that reports region bounds.

    Terms with negligible coefficients are ignored. Every other term is asked
    for Term.get_reg_sgn_bds(); terms answering None form the remainder
    `rest`. For the term that answers (reg, sn, bds), the bounds are scaled by
    its coefficient, and sn is negated when the coefficient is negative.

    Returns (reg, sn, [b + rest for b in bds]). If no term reports bounds,
    reg is the universe region, sn is 0 and the list is empty. Returns None
    if another bound-reporting term appears after one with nonzero sign.
    """
    reg = Region.universe()
    sn = 0
    bds = []
    rest = Expr.zero()
    for (a, c) in self.terms:
        if abs(c) <= PsiOpts.settings["eps"]:
            continue
        t = a.get_reg_sgn_bds()
        if t is None:
            # Stored by reference; `b + rest` below copies the terms.
            rest.terms.append((a, c))
        else:
            if sn != 0:
                return None
            reg, sn, bds = t
            bds = [b * c for b in bds]
            if c < 0:
                sn = -sn
    return (reg, sn, [b + rest for b in bds])
def sort(self):
    """Sort the expression in place.

    Each term is first sorted internally (Term.sort). Terms are then ordered
    by decreasing coefficient (compared after rounding c * 1000), then by
    sorting_priority(), then by string form. The cached hash is cleared
    because the terms were modified in place, as other mutators also do.
    """
    for x, c in self.terms:
        x.sort()
    self.terms.sort(key = lambda a: (-round(a[1] * 1000.0),
                                     a[0].sorting_priority(), str(a[0])))
    self.mhash = None
def sorting_tuple_eqn(self):
    """Sort key (sorting_priority(), string of the inequality self >= 0 with real variables on the left)."""
    key_str = self.tostring_eqn(">=", lhsvar = "real")
    priority = self.sorting_priority()
    return (priority, key_str)
def tostring(self, style = 0, tosort = False, add_bracket = False, tosort_pm = False):
    """Convert to string
    Parameters:
        style : Style of string conversion
            STR_STYLE_STANDARD : I(X,Y;Z|W)
            STR_STYLE_PSITIP : I(X+Y&Z|W)
        tosort : order terms by decreasing coefficient, then by string form
        add_bracket : wrap the result in brackets (LaTeX-sized in LaTeX style)
            when there are at least 2 terms
        tosort_pm : if tosort is False, list nonnegative-coefficient terms first
    Terms whose coefficient is within eps of 0 are skipped; an empty result
    is "0".
    """
    style = iutil.convert_str_style(style)
    termlist = self.terms
    if tosort:
        termlist = sorted(termlist, key=lambda a: (-round(a[1] * 1000.0),
                          a[0].tostring(style = style, tosort = tosort)))
    elif tosort_pm:
        termlist = sorted(termlist, key=lambda a: a[1] < 0)
    use_bracket = add_bracket and len(termlist) >= 2
    float_style = self.get_meta("float_style")
    r = ""
    if use_bracket:
        if style & PsiOpts.STR_STYLE_LATEX:
            r += "\\left("
        else:
            r += "("
    first = True
    for (a, c) in termlist:
        if abs(c) <= PsiOpts.settings["eps"]:
            continue
        if c > 0.0 and not first:
            r += "+"
        if a.isone():
            # Constant term: print the coefficient alone.
            r += iutil.float_tostr(c, style, bracket = False, force_float = float_style)
        else:
            need_bracket = False
            if abs(c - 1.0) < PsiOpts.settings["eps"]:
                pass
            elif abs(c + 1.0) < PsiOpts.settings["eps"]:
                r += "-"
                need_bracket = True
            else:
                r += iutil.float_tostr(c, style, force_float = float_style)
                if style & PsiOpts.STR_STYLE_PSITIP:
                    r += "*"
                need_bracket = True
            r += a.tostring(style = style, tosort = tosort, add_bracket = need_bracket)
        first = False
    if r == "":
        return "0"
    if use_bracket:
        if style & PsiOpts.STR_STYLE_LATEX:
            r += "\\right)"
        else:
            r += ")"
    return r
def __str__(self):
    """String form using the global "str_style" and "str_tosort" settings."""
    settings = PsiOpts.settings
    return self.tostring(settings["str_style"], tosort = settings["str_tosort"])
def __repr__(self):
    """Representation in the "str_style_repr" style, simplified first if "repr_simplify" is set."""
    target = self
    if PsiOpts.settings.get("repr_simplify", False):
        target = self.simplified_quick()
    return target.tostring(PsiOpts.settings["str_style_repr"])
@latex_postprocess
def _latex_(self):
    """LaTeX string, of the simplified expression if "repr_simplify" is set."""
    source = self
    if PsiOpts.settings.get("repr_simplify", False):
        source = self.simplified_quick()
    return source.tostring(iutil.convert_str_style("latex"))
def __hash__(self):
    """Order-independent hash over (hash(term), coefficient) pairs, cached in mhash."""
    if self.mhash is not None:
        return self.mhash
    pairs = sorted((hash(term), coeff) for term, coeff in self.terms)
    self.mhash = hash(tuple(pairs))
    return self.mhash
def table(self, *args, **kwargs):
    """Plot the information diagram as a Karnaugh map.
    """
    whole = universe()
    return whole.table(*args, self, **kwargs)
def venn(self, *args, **kwargs):
    """Plot the information diagram as a Venn diagram.
    Can handle up to 5 random variables (uses Branko Grunbaum's Venn diagram for n=5).
    """
    whole = universe()
    return whole.venn(*args, self, **kwargs)
def isregtermpresent(self):
    """Whether any term is a region term or involves a random variable with a region (b.reg)."""
    for term, _ in self.terms:
        if any(rv.reg is not None for rv in term.allcomprv_shallow().varlist):
            return True
        if term.get_type() == TermType.REGION:
            return True
    return False
def sortIc(self):
    """Stable in-place sort putting terms with isic2() first, ordered by size().

    All other terms get key 100000 and keep their relative order. The hash
    cache is cleared.
    """
    fallback = 100000
    self.terms.sort(key=lambda pair: pair[0].size() if pair[0].isic2() else fallback)
    self.mhash = None
def __le__(self, other):
    """Constraint self <= other: a Region whose single inequality expression is other - self."""
    if isinstance(other, (CompArray, ExprArray)):
        return NotImplemented
    other = iutil.ensure_expr(other)
    empty = Comp.empty
    return Region([other - self], [], empty(), empty(), empty())
def __lt__(self, other):
    """Strict constraint self < other: inequality expression other - self - eps."""
    if isinstance(other, (CompArray, ExprArray)):
        return NotImplemented
    other = iutil.ensure_expr(other)
    empty = Comp.empty
    return Region([other - self - Expr.eps()], [], empty(), empty(), empty())
def __ge__(self, other):
    """Constraint self >= other: a Region whose single inequality expression is self - other."""
    if isinstance(other, (CompArray, ExprArray)):
        return NotImplemented
    other = iutil.ensure_expr(other)
    empty = Comp.empty
    return Region([self - other], [], empty(), empty(), empty())
def __gt__(self, other):
    """Strict constraint self > other: inequality expression self - other - eps."""
    if isinstance(other, (CompArray, ExprArray)):
        return NotImplemented
    other = iutil.ensure_expr(other)
    empty = Comp.empty
    return Region([self - other - Expr.eps()], [], empty(), empty(), empty())
def __eq__(self, other):
    """Constraint self == other: a Region whose single equality expression is other - self.

    Note that this builds a constraint object rather than returning a bool.
    """
    if isinstance(other, (CompArray, ExprArray)):
        return NotImplemented
    other = iutil.ensure_expr(other)
    empty = Comp.empty
    return Region([], [other - self], empty(), empty(), empty())
def __ne__(self, other):
    """Constraint self != other: the complement of the equality constraint."""
    if isinstance(other, (CompArray, ExprArray)):
        return NotImplemented
    equality = self == other
    return ~RegionOp.inter([equality])
def equiv(self, other):
    """Whether self is equal to other: both self <= other and other <= self check out."""
    forward = (self <= other).check()
    return forward and (other <= self).check()
def real_present(self):
    """Whether any term is of REAL type."""
    return any(term.get_type() == TermType.REAL for term, _ in self.terms)
def rv_weight(self):
    """Coefficient-weighted sum of Term.rv_weight() over the terms."""
    return sum((term.rv_weight() * coeff for term, coeff in self.terms), 0.0)
def complexity(self):
    """Heuristic complexity score used when sorting expressions.

    Per term: the numerator + denominator size of the coefficient as a
    fraction (denominator limited to "max_denom", capped at 8), plus 4 times
    Term.complexity(). Adds 500 if any term satisfies ishc(), then
    1000 * len(self.allcomp()).
    """
    max_denom = PsiOpts.settings["max_denom"]
    score = 0
    for term, coeff in self.terms:
        frac = fractions.Fraction(coeff).limit_denominator(max_denom)
        score += min(abs(frac.numerator) + abs(frac.denominator), 8)
        score += term.complexity() * 4
    if any(term.ishc() for term, _ in self.terms):
        score += 500
    score += len(self.allcomp()) * 1000
    return score
def sorting_priority(self):
    """Sort priority: complexity(), plus 100000 when no REAL term is present."""
    base = 0 if self.real_present() else 100000
    return base + self.complexity()
def get_coeff(self, b, get_pos = False):
    """Get the coefficient of b in this expression.

    Parameters:
        b       : a Term, a Comp, or an Expr.
                  Term : sum of the coefficients of terms equal to b.
                  Comp : sum of the coefficients of terms with ishc() whose
                         x[0] - z contains b.
                  Expr : over the terms of b with nonzero coefficient, the
                         ratio (coefficient here) / (coefficient in b) of
                         smallest magnitude, provided all ratios share one
                         sign; 0.0 if some term of b is missing here, if the
                         signs are mixed, or if b has no such terms.
        get_pos : if True, also return a position in self.terms.

    Returns:
        The coefficient, or (coefficient, position) if get_pos is True. The
        position is the index of the last matching term (for an Expr b, the
        smallest such index among b's terms), or None if nothing matched.
    """
    if isinstance(b, Term):
        r = 0.0
        pos = None
        for i, (a, c) in enumerate(self.terms):
            if a == b:
                r += c
                pos = i
        if get_pos:
            return (r, pos)
        return r
    if isinstance(b, Comp):
        r = 0.0
        pos = None
        for i, (a, c) in enumerate(self.terms):
            if a.ishc() and (a.x[0] - a.z).super_of(b):
                r += c
                pos = i
        if get_pos:
            return (r, pos)
        return r
    r = None
    pos = None
    for a2, c2 in b.terms:
        coeff2 = b.get_coeff(a2)
        if coeff2 == 0:
            continue
        coeff1, tpos = self.get_coeff(a2, get_pos = True)
        if coeff1 == 0:
            if get_pos:
                return (0.0, None)
            return 0.0
        if pos is None:
            pos = tpos
        elif tpos is not None:
            pos = min(pos, tpos)
        t = coeff1 / coeff2
        if r is None:
            r = t
        else:
            # Keep the ratio of smallest magnitude; mixed signs give 0.
            if r > 0:
                if t < 0:
                    if get_pos:
                        return (0.0, None)
                    return 0.0
                r = min(r, t)
            else:
                if t > 0:
                    if get_pos:
                        return (0.0, None)
                    return 0.0
                r = max(r, t)
    if r is None:
        if get_pos:
            return (0.0, None)
        return 0.0
    if get_pos:
        return (r, pos)
    return r
def __contains__(self, other):
    """Membership test (`other in expr`).

    For a Term or Expr, True if get_coeff gives a nonzero coefficient;
    otherwise True if other is contained in any term.
    """
    if isinstance(other, (Term, Expr)):
        return self.get_coeff(other) != 0
    return any(other in term for term, _ in self.terms)
def partition(self, p):
    """Split each term into its restrictions to the parts in p, in place.

    A term is replaced by term.restricted(px) for each px in p, all with the
    same coefficient. If any restriction is None, the term is kept unchanged.
    Ends with simplify_quick().
    """
    self.mhash = None
    old_terms = self.terms
    self.terms = []
    for term, coeff in old_terms:
        pieces = [term.restricted(px) for px in p]
        if any(piece is None for piece in pieces):
            self.terms.append((term, coeff))
        else:
            self.terms.extend((piece, coeff) for piece in pieces)
    self.simplify_quick()
def partitioned(self, p):
    """Return a partitioned copy; self is left unchanged."""
    result = self.copy()
    result.partition(p)
    return result
def isbalanced(self, v = None):
    """Whether every random variable in v has a coefficient within eps of 0.

    v defaults to allcomprv_shallow().
    """
    if v is None:
        v = self.allcomprv_shallow()
    tol = PsiOpts.settings["eps"]
    for x in v:
        if not abs(self.get_coeff(x)) <= tol:
            return False
    return True
def balanced(self, v = None, w = None, sn = 1):
    """Return a balanced copy obtained by adding conditional-entropy terms.

    For each x in v, the needed amount is -coeff(x) * sn. If any needed
    amount is below -eps, None is returned. Otherwise the following is
    repeated while some amounts remain above eps: let t be the smallest
    remaining amount and X the join of the variables still pending; add
    sn * t * H(X | w - X) and reduce the pending amounts by t.
    v and w default to allcomprv_shallow(). The result is simplified_quick().
    """
    ceps = PsiOpts.settings["eps"]
    if v is None:
        v = self.allcomprv_shallow()
    if w is None:
        w = self.allcomprv_shallow()
    result = self.copy()
    pending = []
    for x in v:
        need = -result.get_coeff(x) * sn
        if need < -ceps:
            return None
        if need > ceps:
            pending.append((x, need))
    while pending:
        step = min(amount for _, amount in pending)
        group = sum((x for x, _ in pending), Comp.empty())
        result += Expr.Hc(group, w - group) * step * sn
        pending = [(x, amount - step) for x, amount in pending if amount - step > ceps]
    return result.simplified_quick()
def get_sign(self, b):
    """Return whether Expr is increasing or decreasing in random variable b.

    Returns 1 (increasing), -1 (decreasing) or 0 (cannot tell). Terms with
    negligible coefficient or not involving b are ignored. Only IC terms with
    1 or 2 components in which b appears in exactly one place are understood.
    b in an x component contributes the coefficient's sign; b in z
    contributes the opposite sign, and is not allowed for 2-component terms.
    Any other term involving b, or conflicting contributions, give 0.
    """
    r = 0
    for (a, c) in self.terms:
        sn = 0
        if abs(c) <= PsiOpts.settings["eps"] or not a.ispresent(b):
            continue
        sn = 1 if c > 0 else -1
        if a.get_type() == TermType.IC:
            if len(a.x) > 2 or len(a.x) == 0:
                return 0
            # Count where b occurs: in z (nz) and in the x components (nx).
            nx = 0
            nz = 0
            if a.z.ispresent(b):
                nz += 1
            for x in a.x:
                if x.ispresent(b):
                    nx += 1
            if nx + nz != 1:
                return 0
            if len(a.x) == 2 and nz == 1:
                return 0
            if nz == 1:
                sn = -sn
        else:
            return 0
        if r != 0 and r != sn:
            return 0
        r = sn
    return r
def get_var_avoid(self, x):
    """Intersect the Term.get_var_avoid(x) results of all terms.

    Results that are None are skipped once an accumulated value exists.
    None is returned if there are no terms or every result is None.
    """
    common = None
    for term, _coeff in self.terms:
        found = term.get_var_avoid(x)
        if common is None:
            common = found
        elif found is not None:
            common = common.inter(found)
    return common
def name_avoid(self, name0):
    """Name derived from name0 via IVarIndex.name_avoid after recording this expression's variables.

    Presumably the result avoids clashing with names already used here;
    confirm against IVarIndex.name_avoid.
    """
    idx = IVarIndex()
    self.record_to(idx)
    return idx.name_avoid(name0)
def isconvex(self, v = None, bnet = None):
    """Check whether expression is convex with respect to random variables v and real variables.
    False return value does NOT necessarily mean expression is not convex.
    Checked via convexity of the region self <= t for a fresh real variable t.
    """
    t = Expr.real(self.name_avoid("t"))
    region = self <= t
    return region.isconvex(v, bnet)
def isconcave(self, v = None, bnet = None):
    """Check whether expression is concave with respect to random variables v and real variables.
    False return value does NOT necessarily mean expression is not concave.
    Checked via convexity of the region self >= t for a fresh real variable t.
    """
    t = Expr.real(self.name_avoid("t"))
    region = self >= t
    return region.isconvex(v, bnet)
def isaffine(self, v = None, bnet = None):
    """Check whether expression is affine with respect to random variables v and real variables.
    False return value does NOT necessarily mean expression is not affine.
    Requires both self <= t and self >= t to be convex regions for a fresh real t.
    """
    t = Expr.real(self.name_avoid("t"))
    below_convex = (self <= t).isconvex(v, bnet)
    return below_convex and (self >= t).isconvex(v, bnet)
def concave_envelope(self, v = None, bnet = None, q = None):
    """Compute the upper concave envelope with respect to random variables v.

    The region {self >= t} (t a fresh real variable named "env_<count>") is
    convexified in v, and t is then maximized over it.

    C. Nair, "Upper concave envelopes and auxiliary random variables," International Journal of
    Advances in Engineering Sciences and Applied Mathematics, vol. 5, no. 1, pp. 12-20, 2013.
    """
    level = Expr.real(self.name_avoid("env_" + str(iutil.get_count("env"))))
    hull = (self >= level).convexified(v, bnet, q = q)
    return hull.maximum(level, None, allow_reuse = True)
def convex_envelope(self, v = None, bnet = None, q = None):
    """Compute the lower convex envelope with respect to random variables v.

    The region {self <= t} (t a fresh real variable named "env_<count>") is
    convexified in v, and t is then minimized over it.

    C. Nair, "Upper concave envelopes and auxiliary random variables," International Journal of
    Advances in Engineering Sciences and Applied Mathematics, vol. 5, no. 1, pp. 12-20, 2013.
    """
    level = Expr.real(self.name_avoid("env_" + str(iutil.get_count("env"))))
    hull = (self <= level).convexified(v, bnet, q = q)
    return hull.minimum(level, None, allow_reuse = True)
def remove_term(self, b):
    """Remove every term equal to b, in place; return self."""
    self.terms = [pair for pair in self.terms if pair[0] != b]
    self.mhash = None
    return self
def removed_term(self, b):
    """Return a new Expr without the terms equal to b (meta data is copied)."""
    kept = [pair for pair in self.terms if pair[0] != b]
    return Expr(kept, meta = iutil.copy(self.meta))
def symm_sort(self, terms):
    """Sort the random variables in terms assuming symmetry among those terms.

    Each variable of `terms` receives a score accumulated over the
    components of every term it occurs in. A component contributes
    round(1000 * coeff) * weight + 5 * (size of the component) + 11, where
    weight is 2 for arguments x and 1 for the conditioning part z. With
    `order` the positions sorted by decreasing score, variable terms[k] is
    then renamed to terms[order[k]], in place. Returns None.
    """
    self.mhash = None
    index = IVarIndex()
    terms.record_to(index)
    rvs = index.comprv
    n = len(rvs.varlist)
    score = [0] * n
    for term, coeff in self.terms:
        scaled = int(round(coeff * 1000))
        parts = [(comp, 2) for comp in term.x]
        parts.append((term.z, 1))
        for comp, weight in parts:
            mask = index.get_mask(comp)
            bonus = scaled * weight + iutil.bitcount(mask) * 5 + 11
            for k in range(n):
                if mask & (1 << k):
                    score[k] += bonus
    order = sorted(range(n), key = lambda k: score[k], reverse = True)
    placeholders = Comp.array("#TMPVAR", 0, n)
    # Go through placeholders so that no variable is overwritten before it
    # has itself been renamed.
    for k in range(n):
        self.substitute(rvs[k], placeholders[k])
    for k in range(n):
        self.substitute(placeholders[k], rvs[order[k]])
def coeff_sum(self):
    """Return the sum of all coefficients (0 for an empty expression)."""
    return sum(coeff for _, coeff in self.terms)
def coeff_sign(self):
    """Return the common sign of the coefficients.

    1 if some coefficient is positive and none is negative, -1 if some is
    negative and none is positive, and 0 if both signs occur or neither does.
    """
    positive = any(coeff > 0 for _, coeff in self.terms)
    negative = any(coeff < 0 for _, coeff in self.terms)
    if positive == negative:
        return 0
    return 1 if positive else -1
def simplify_mul(self, mul_allowed = 0):
    """Rescale the coefficients towards small integers, in place.

    Parameters:
        mul_allowed: 0 -- only reset the cached hash.
            >= 1 -- compute denom, the lcm of the denominators of the
            non-constant coefficients (each approximated by a fraction
            with denominator at most the "max_denom" setting). If
            0 < denom <= "max_denom_mul" and multiplying by denom makes every
            non-constant coefficient an integer within eps, multiply all
            coefficients by denom, then divide them by the gcd of the
            nonzero numerators.
            >= 2 -- additionally negate the multiplier when the sum of
            coefficients is negative.
    """
    self.mhash = None
    if mul_allowed > 0:
        max_denom = PsiOpts.settings["max_denom"]
        max_denom_mul = PsiOpts.settings["max_denom_mul"]
        denom = 1
        for (a, c) in self.terms:
            # Constant terms are not required to become integers.
            if a.isone():
                continue
            denom = iutil.lcm(fractions.Fraction(c).limit_denominator(
                max_denom).denominator, denom)
            # Give up early once the multiplier is already too large.
            if denom > max_denom_mul:
                break
        if denom > 0 and denom <= max_denom_mul:
            if mul_allowed >= 2:
                # A negative multiplier is only permitted at level >= 2.
                if self.coeff_sum() < 0:
                    denom = -denom
            if all(a.isone() or abs(c * denom - round(c * denom)) <= PsiOpts.settings["eps"] for (a, c) in self.terms):
                self.terms = [(a, iutil.float_snap(c * denom)) for (a, c) in self.terms]
                # Divide out a common integer factor of the new coefficients.
                num = None
                for (a, c) in self.terms:
                    tnum = abs(fractions.Fraction(c).limit_denominator(max_denom).numerator)
                    if tnum == 0:
                        continue
                    if num is None:
                        num = tnum
                    else:
                        num = iutil.gcd(num, tnum)
                if num is not None and num > 1:
                    self.terms = [(a, iutil.float_snap(c / num)) for (a, c) in self.terms]
def mi_disjoint(self):
    """Make the arguments of every multi-argument IC term disjoint, in place.

    If all arguments x[0], x[1], ... of an IC term share the variables W,
    W is removed from each argument and added to the conditioning part, and
    a term H(W | z) with the same coefficient is inserted right after it,
    e.g. I(X,W; Y,W | Z) -> I(X; Y | Z,W) + H(W | Z).
    The inserted term is not itself re-examined.
    """
    pos = 0
    while pos < len(self.terms):
        term, coeff = self.terms[pos]
        pos += 1
        if term.get_type() != TermType.IC or len(term.x) < 2:
            continue
        common = term.x[0]
        for other in term.x[1:]:
            common = common.inter(other)
        if common.isempty():
            continue
        self.terms.insert(pos, (Term.Hc(common, term.z), coeff))
        for k, comp in enumerate(term.x):
            term.x[k] = comp - common
        term.z = term.z + common
        # Skip over the H(W | z) term just inserted.
        pos += 1
def combine_same_terms(self):
    """Merge equal terms by summing their coefficients, in place.

    Each later duplicate is folded into the earliest equal term. Terms whose
    coefficient is within eps of zero, or which are identically zero, are
    then dropped. The cached hash is invalidated because self.terms changes,
    as in the other in-place mutators.
    """
    ceps = PsiOpts.settings["eps"]
    self.mhash = None
    for i in range(len(self.terms)):
        for j in range(i):
            if self.terms[i][0] == self.terms[j][0]:
                self.terms[j] = (self.terms[j][0], self.terms[j][1] + self.terms[i][1])
                # Zero out the duplicate; it is removed by the filter below.
                self.terms[i] = (self.terms[i][0], 0.0)
                break
    self.terms = [(a, c) for (a, c) in self.terms if abs(c) > ceps and not a.iszero()]
def simplify_balance(self, reg = None, bnet = None, quick = False, term_allow = None):
    """Move variables with zero coefficient into the conditioning part, in place.

    For each variable x (from allcomprv_shallow) with |get_coeff(x)| <= eps,
    every conditional-entropy term H(X0 | Z) whose X0 - Z contains x is
    rewritten as H(X0 - x | Z, x). When Z is nonempty, a term -c * I(x; Z) is
    also appended. Near-zero and zero terms are dropped at the end.
    The parameters reg, bnet, quick and term_allow are accepted but not used.

    NOTE(review): the identity in the comment below contains a "+ H(X)" term
    that this code does not add. Presumably this is intentional because x has
    zero coefficient here (the "balance" condition), but confirm.
    """
    v = self.allcomprv_shallow()
    ceps = PsiOpts.settings["eps"]
    for x in v:
        if abs(self.get_coeff(x)) > ceps:
            continue
        # Appended terms are not visited: range() is fixed at loop entry.
        for i in range(len(self.terms)):
            a, c = self.terms[i]
            if a.ishc() and (a.x[0] - a.z).super_of(x):
                # H(X,Y|Z) = H(Y|X,Z) + H(X|Z) = H(Y|X,Z) - I(X;Z) + H(X)
                if not a.z.isempty():
                    self.terms.append((Term.I(x, a.z).copy(), -c))
                a.x[0] -= x
                a.z += x
    self.terms = [(a, c) for (a, c) in self.terms if abs(c) > ceps and not a.iszero()]
def simplify(self, reg = None, bnet = None, quick = False, term_allow = None):
    """Simplify the expression in place and return self.

    Steps:
    1. Unless quick, and only with the "simplify_regterm" setting, call
       simplify_regterm(reg).
    2. Simplify each term, make multi-argument IC terms disjoint
       (mi_disjoint), and with the "simplify_balance" setting call
       simplify_balance.
    3. Repeat until a pass changes nothing: merge equal terms, drop near-zero
       terms, and try to combine pairs of terms with Term.try_iadd /
       try_isub / try_isub_flipsign when their coefficients are equal or
       opposite. With the "simplify_reduce_coeff" setting, pairs whose
       coefficients differ are also combined by splitting off the common
       part of the coefficient. After any change, terms are re-simplified.
    4. If term_allow is None and the "term_allow" setting is not None, call
       simplify_break_allow with that setting.

    Parameters:
        reg: passed to simplify_regterm and Term.simplify.
        bnet: passed to Term.simplify and the try_* combination methods.
        quick: if True, skip simplify_regterm.
        term_allow: passed to simplify_balance and the try_* methods.
    """
    ceps = PsiOpts.settings["eps"]
    reduce_coeff = PsiOpts.settings.get("simplify_reduce_coeff", False)
    self.mhash = None
    if not quick and PsiOpts.settings.get("simplify_regterm", False):
        self.simplify_regterm(reg)
    for (a, c) in self.terms:
        a.simplify(reg, bnet)
    self.mi_disjoint()
    if PsiOpts.settings.get("simplify_balance", False):
        self.simplify_balance(reg, bnet, quick, term_allow)
    did = True
    while did:
        did = False
        # Fold each term into the first earlier term equal to it.
        for i in range(len(self.terms)):
            for j in range(i):
                if self.terms[i][0] == self.terms[j][0]:
                    self.terms[j] = (self.terms[j][0], self.terms[j][1] + self.terms[i][1])
                    self.terms[i] = (self.terms[i][0], 0.0)
                    did = True
                    break
        self.terms = [(a, c) for (a, c) in self.terms if abs(c) > ceps and not a.iszero()]
        # Pairwise combination. A term whose coefficient was zeroed earlier in
        # this pass is skipped by the abs(...) > ceps checks.
        for i in range(len(self.terms)):
            if abs(self.terms[i][1]) > ceps:
                for j in range(len(self.terms)):
                    if i != j and abs(self.terms[j][1]) > ceps:
                        ci = self.terms[i][1]
                        cj = self.terms[j][1]
                        if abs(ci - cj) <= ceps:
                            # Equal coefficients: c*A + c*B -> c*(A+B).
                            if self.terms[i][0].try_iadd(self.terms[j][0], bnet = bnet, term_allow = term_allow):
                                self.terms[j] = (self.terms[j][0], 0.0)
                                did = True
                        elif reduce_coeff and ci * cj > 0:
                            # Same sign, different magnitude.
                            if abs(ci) > abs(cj):
                                # ci*A + cj*B = cj*(A+B) + (ci-cj)*A
                                ti = self.terms[i][0].copy()
                                if self.terms[i][0].try_iadd(self.terms[j][0], bnet = bnet, term_allow = term_allow):
                                    self.terms[i] = (self.terms[i][0], cj)
                                    self.terms[j] = (ti, ci - cj)
                                    did = True
                            else:
                                # ci*A + cj*B = ci*(A+B) + (cj-ci)*B
                                if self.terms[i][0].try_iadd(self.terms[j][0], bnet = bnet, term_allow = term_allow):
                                    self.terms[j] = (self.terms[j][0], cj - ci)
                                    did = True
                        elif abs(ci + cj) <= ceps:
                            # Opposite coefficients: c*A - c*B -> c*(A-B).
                            if self.terms[i][0].try_isub(self.terms[j][0], bnet = bnet, term_allow = term_allow):
                                self.terms[j] = (self.terms[j][0], 0.0)
                                did = True
                            elif self.terms[i][0].try_isub_flipsign(self.terms[j][0], bnet = bnet, term_allow = term_allow):
                                # Alternative subtraction that leaves term j with coefficient ci.
                                self.terms[j] = (self.terms[j][0], ci)
                                did = True
                        elif reduce_coeff and ci * cj < 0:
                            # Opposite sign, different magnitude.
                            if abs(ci) > abs(cj):
                                # ci*A + cj*B = -cj*(A-B) + (ci+cj)*A
                                ti = self.terms[i][0].copy()
                                if self.terms[i][0].try_isub(self.terms[j][0], bnet = bnet, term_allow = term_allow):
                                    self.terms[i] = (self.terms[i][0], -cj)
                                    self.terms[j] = (ti, ci + cj)
                                    did = True
                            else:
                                # ci*A + cj*B = ci*(A-B) + (cj+ci)*B
                                if self.terms[i][0].try_isub(self.terms[j][0], bnet = bnet, term_allow = term_allow):
                                    self.terms[j] = (self.terms[j][0], cj + ci)
                                    did = True
        self.terms = [(a, c) for (a, c) in self.terms if abs(c) > ceps and not a.iszero()]
        if did:
            for (a, c) in self.terms:
                a.simplify(reg, bnet)
    #self.terms = [(a, iutil.float_snap(c)) for (a, c) in self.terms
    # if abs(c) > ceps and not a.iszero()]
    if term_allow is None and PsiOpts.settings["term_allow"] is not None:
        self.simplify_break_allow(reg, bnet, quick, PsiOpts.settings["term_allow"])
    return self
def simplify_break_allow(self, reg = None, bnet = None, quick = False, term_allow = None):
    """Rewrite terms of types not permitted by term_allow, in place; return self.

    term_allow is a bit mask of TermAllowType flags.

    - If I or I3 is not allowed: repeatedly pick the first IC term with too
      many arguments (at least 2 if I is not allowed, at least 3 otherwise).
      For each argument k, the term is replaced by "the term without argument
      k" minus "that term additionally conditioned on x[k]". The
      replacement with the lowest complexity after simplify() is kept.
    - If HC or IC is not allowed: every IC term with a nonempty conditioning
      part Z (single-argument if HC is disallowed, multi-argument if IC is
      disallowed) is expanded by adding Z to every argument, clearing Z and
      subtracting H(Z). The result is then re-simplified.

    reg, bnet and quick are forwarded to simplify().
    """
    if (not term_allow & TermAllowType.I) or (not term_allow & TermAllowType.I3):
        did = True
        while did:
            did = False
            for i in range(len(self.terms)):
                a, c = self.terms[i]
                if a.get_type() != TermType.IC:
                    continue
                if len(a.x) <= 1:
                    continue
                if term_allow & TermAllowType.I and len(a.x) <= 2:
                    continue
                minc = 0
                mint = None
                for k in range(len(a.x)):
                    t = list(self.terms)
                    # a2: argument k removed; a3: a2 additionally conditioned on x[k].
                    a2 = a.copy()
                    a2.x.pop(k)
                    a3 = a2.copy()
                    a3.z += a.x[k]
                    t[i:i+1] = [(a2, c), (a3, -c)]
                    t = Expr(t)
                    t.simplify(reg, bnet, quick, term_allow)
                    tc = t.complexity()
                    if mint is None or tc < minc:
                        minc = tc
                        mint = t
                self.copy_(mint)
                # Restart the scan on the rewritten expression.
                did = True
                break
    if (not term_allow & TermAllowType.HC) or (not term_allow & TermAllowType.IC):
        did = False
        t = list(self.terms)
        i = 0
        while i < len(t):
            a, c = t[i]
            if a.get_type() != TermType.IC:
                i += 1
                continue
            if len(a.z) == 0:
                i += 1
                continue
            if len(a.x) <= 1:
                if term_allow & TermAllowType.HC:
                    i += 1
                    continue
            else:
                if term_allow & TermAllowType.IC:
                    i += 1
                    continue
            # e.g. H(X|Z) = H(X,Z) - H(Z), I(X;Y|Z) = I(X,Z;Y,Z) - H(Z)
            a2 = a.copy()
            for k in range(len(a2.x)):
                a2.x[k] += a.z
            a2.z = Comp.empty()
            t[i:i+1] = [(a2, c), (Term.H(a.z.copy()), -c)]
            i += 2
            did = True
        if did:
            t = Expr(t)
            self.copy_(t)
            self.simplify(reg, bnet, quick, term_allow)
    return self
def simplified(self, reg = None, bnet = None, quick = False):
    """Return a simplified copy of the expression; self is left unchanged."""
    result = self.copy()
    result.simplify(reg, bnet, quick)
    return result
def simplify_more_cond(self, bnet = None):
    """Placeholder hook: performs no simplification and always returns None."""
    return None
def simplified_exhaust_inner(self, bnet = None):
    """Try one chain-rule split that lowers the complexity; return it or None.

    For every IC term, every argument x[ix] with more than one variable, and
    every single variable x of that argument, the term is split by the chain
    rule into
      ks[0]: the term with x removed from argument ix, and
      ks[1]: the term with argument ix replaced by x and additionally
             conditioned on the rest of argument ix.
    For each order, one half replaces the original term, the expression is
    quick-simplified, the other half is added, and it is quick-simplified
    again. The first result with lower complexity than self is returned.
    Expressions with at most two terms are not attempted. Candidates are
    built from copies of the terms, so self is only read.
    """
    if len(self.terms) <= 2:
        return None
    scom = self.complexity()
    for i, (a, c) in enumerate(self.terms):
        if a.get_type() == TermType.IC:
            a2 = a.copy()
            for ix in range(len(a2.x)):
                if len(a2.x[ix]) <= 1:
                    continue
                for x in a2.x[ix]:
                    ks = [Term([p - x if ip == ix else p.copy() for ip, p in enumerate(a2.x)], a2.z.copy()),
                          Term([x.copy() if ip == ix else p.copy() for ip, p in enumerate(a2.x)], a2.z + (a2.x[ix] - x))]
                    # print(str(a2) + " " + str(ks[0]) + " " + str(ks[1]))
                    for it in range(2):
                        # Substitute half `it` for term i, simplify, then add the other half.
                        expr = Expr([(a4.copy(), c4) if i4 != i else (ks[it].copy(), c) for i4, (a4, c4) in enumerate(self.terms)])
                        # print(" " + str(expr))
                        expr.simplify_quick(bnet = bnet)
                        # print(" " + str(expr))
                        expr += Expr([(ks[1 - it].copy(), c)])
                        # print(" " + str(expr))
                        expr.simplify_quick(bnet = bnet)
                        # print(" " + str(expr))
                        # print()
                        if expr.complexity() < scom:
                            return expr
    return None
def break_hc(self):
    """Rewrite each H(X|Z) with nonempty Z as H(X) - I(X;Z), in place.

    For every conditional-entropy term with coefficient c, the conditioning
    part is cleared and a new term -c * I(X;Z) is appended. Only the terms
    present at the start are examined, so appended terms are left alone.
    The appended term is copied so that it does not share Comp objects with
    the modified H(X) term; later in-place updates such as `a.x[0] -= x`
    would otherwise change both. simplify_balance copies for the same reason.
    """
    self.mhash = None
    olen = len(self.terms)
    for i in range(olen):
        a, c = self.terms[i]
        if a.ishc() and not a.z.isempty():
            self.terms.append((Term.I(a.x[0], a.z).copy(), -c))
            a.z = Comp.empty()
def simplify_break_hc(self, bnet = None):
    """Apply break_hc plus a quick simplification if that lowers the complexity, in place.

    The rewrite is tried on a copy; self is replaced only when the result
    is strictly less complex.
    """
    candidate = self.copy()
    candidate.break_hc()
    candidate.simplify_quick(bnet = bnet)
    if candidate.complexity() >= self.complexity():
        return
    self.copy_(candidate)
    self.mhash = None
def simplified_break_hc(self, bnet = None):
    """Return a copy with simplify_break_hc applied; self is left unchanged."""
    result = self.copy()
    result.simplify_break_hc(bnet = bnet)
    return result
def simplify_exhaust(self, bnet = None):
    """Apply complexity-lowering rewrites until none applies, in place.

    Terms are sorted by complexity, simplify_break_hc is tried, terms are
    sorted again, and then the result of simplified_exhaust_inner replaces
    the terms for as long as it is not None.
    """
    def by_complexity(pair):
        return pair[0].complexity()
    self.terms.sort(key = by_complexity)
    self.simplify_break_hc(bnet = bnet)
    self.terms.sort(key = by_complexity)
    self.mhash = None
    improved = self.simplified_exhaust_inner(bnet = bnet)
    while improved is not None:
        self.terms = improved.terms
        self.mhash = None
        improved = self.simplified_exhaust_inner(bnet = bnet)
def simplified_exhaust(self, bnet = None):
    """Return a copy with simplify_exhaust applied; self is left unchanged."""
    result = self.copy()
    result.simplify_exhaust(bnet = bnet)
    return result
def simplify_target(self, target, bnet = None):
    """Try to rewrite self around each IC expression in target, in place.

    For each element `term` of target whose first term is of IC type,
    (self - term) is quick-simplified and term is added back; the result
    replaces self when its complexity is not larger. Elements without any
    terms (zero expressions) are skipped instead of raising IndexError.
    Returns self.
    """
    for term in target:
        if not term.terms or term.terms[0][0].get_type() != TermType.IC:
            continue
        candidate = (self - term).simplified_quick(bnet = bnet) + term
        if candidate.complexity() <= self.complexity():
            self.copy_(candidate)
    return self
def simplify_quick(self, reg = None, **kwargs):
    """Shortcut for simplify(..., quick = True); extra keyword arguments are forwarded."""
    return self.simplify(reg = reg, quick = True, **kwargs)
def simplified_quick(self, reg = None, **kwargs):
    """Shortcut for simplified(..., quick = True); extra keyword arguments are forwarded."""
    return self.simplified(reg = reg, quick = True, **kwargs)
def simplified_prog(self, prog = None, reg = None):
    """Simplify the expression using the dual solution of a linear program.

    Parameters:
        prog: program object from Region.init_simplified_prog. If None, one
            is built from reg over the variables of reg and self.
        reg: Region of assumptions used to build prog when prog is None.
            Defaults to Region.universe() when needed; the previous
            behavior crashed with AttributeError on the default arguments.

    split_ic_real splits self into cself and creal (presumably the
    information-term part and the real-variable part). cself is checked
    with proofs disabled, and the program's dual sum, quick-simplified,
    replaces it; creal is added back unchanged. If the program reports no
    dual sum, a copy of self is returned.
    """
    if prog is None:
        if reg is None:
            reg = Region.universe()
        index = IVarIndex()
        reg.record_to(index)
        self.record_to(index)
        prog = reg.init_simplified_prog(index)
    cself, creal = self.split_ic_real()
    prog.clear_dual()
    with PsiOpts(proof_enabled = False):
        prog.checkexpr_ge0(cself)
    r = prog.get_dual_sum_meta()
    if r is None:
        return self.copy()
    return r.simplified_quick() + creal
def get_ratio(self, other, skip_simplify = False):
    """Try dividing self by other, return None if self is not scalar multiple of other.

    Unless skip_simplify is True, both operands are quick-simplified first.
    Returns 0.0 when self is zero. Otherwise each term of self must match a
    distinct equal term of other with a non-negligible coefficient, and all
    coefficient ratios must agree within eps; the midpoint of the smallest
    and largest ratio is returned.
    """
    if skip_simplify:
        num, den = self, other
    else:
        num, den = self.simplified_quick(), other.simplified_quick()
    if num.iszero():
        return 0.0
    if len(num.terms) != len(den.terms) or den.iszero():
        return None
    eps = PsiOpts.settings["eps"]
    hi = -1e12
    lo = 1e12
    used = [False] * len(den.terms)
    for term_s, coeff_s in num.terms:
        for j, (term_o, coeff_o) in enumerate(den.terms):
            if not used[j] and abs(coeff_o) > eps and term_s == term_o:
                ratio = coeff_s / coeff_o
                hi = max(hi, ratio)
                lo = min(lo, ratio)
                if hi > lo + eps:
                    return None
                used[j] = True
                break
        else:
            # No unused matching term in other.
            return None
    return (hi + lo) * 0.5 if hi <= lo + eps else None
def __truediv__(self, other):
    """Divide by a number or by an Expr.

    - Expr with a constant value: divide every coefficient by that constant.
    - Expr that self is a scalar multiple of: return the ratio (get_ratio).
    - Any other Expr: return a symbolic real-valued quotient term "/".
    - Anything else: divide every coefficient by other.
    """
    if isinstance(other, Expr):
        const = other.get_const()
        if const is None:
            ratio = self.get_ratio(other)
            if ratio is not None:
                return ratio
            quotient = Comp.real(
                iutil.fcn_name_maker("/", [self, other], lname = "/", infix = True)
            )
            return Expr.fromterm(Term(quotient, reg = Region.universe(), fcncall = "/", fcnargs = [self, other]))
        other = const
    return Expr([(term.copy(), coeff / other) for term, coeff in self.terms])
def __rtruediv__(self, other):
    """Compute other / self; a non-Expr numerator is wrapped with Expr.const first."""
    numerator = other if isinstance(other, Expr) else Expr.const(other)
    return numerator / self
def ispresent(self, x):
    """Return whether any variable in x appears in some term of this expression."""
    return any(term.ispresent(x) for term, _ in self.terms)
def affine_present(self):
    """Return whether any of the special terms Expr.one(), Expr.eps() or Expr.inf() occurs here.

    Only those special variables are looked for; no general affinity check
    is performed.
    """
    special = Expr.one() + Expr.eps() + Expr.inf()
    return self.ispresent(special.allcomp())
def try_remove(self, x, sn):
    """Ask every term to remove x, passing sn times the term's coefficient.

    Returns False as soon as one term refuses; changes already made to the
    earlier terms are not rolled back. Otherwise quick-simplifies self and
    returns True.
    """
    if not all(term.try_remove(x, sn * coeff) for term, coeff in self.terms):
        return False
    self.simplify_quick()
    return True
def rename_var(self, name0, name1):
    """Rename variable name0 to name1 in every term, in place; return self.

    Returns self (previously None), consistent with rename_map and the other
    in-place mutators, so calls can be chained.
    """
    for (a, c) in self.terms:
        a.rename_var(name0, name1)
    self.mhash = None
    return self
def rename_map(self, namemap):
    """Rename variables according to namemap in every term, in place; return self."""
    for term, _ in self.terms:
        term.rename_map(namemap)
    self.mhash = None
    return self
def definition(self):
    """Return a new Expr with every term replaced by its definition; coefficients are kept."""
    expanded = [(term.definition(), coeff) for term, coeff in self.terms]
    return Expr(expanded)
def substitute_rate(self, v0, v1):
    """Replace the variables of v0 by the real-valued expression v1, in place; return self.

    All variables of v0 are first merged into v0[0]. The coefficient
    reported by commonpart_coeff(v0[0]) is read, v0[0] is substituted by
    the empty variable, and v1 times that coefficient is added. All of this
    runs with the meta_subs_criteria setting turned off.
    """
    with PsiOpts(meta_subs_criteria = False):
        self.mhash = None
        for idx, var in enumerate(v0):
            if idx == 0:
                continue
            self.substitute(var, v0[0])
        coeff = self.commonpart_coeff(v0[0])
        self.substitute(v0[0], Comp.empty())
        if coeff != 0:
            self += v1 * coeff
    return self
@fcn_substitute
def substitute(self, v0, v1):
    """Substitute v0 by v1 (v1 can be compound), in place; return self.

    - If v0 is an Expr, every term equal to the first term of v0 is
      replaced by v1 times its coefficient. The other terms of v0 are not
      looked at.
    - Otherwise, if v1 is an Expr, delegate to substitute_rate.
    - Otherwise, substitute v0 by v1 inside each term.
    The meta data is also substituted when iutil.check_meta_subs_criteria
    allows it.
    """
    self.mhash = None
    if isinstance(v0, Expr):
        if len(v0.terms) > 0:
            t = v0.terms[0][0]
            tmpterms = self.terms
            self.terms = []
            for (a, c) in tmpterms:
                if a == t:
                    self += v1 * c
                else:
                    self.terms.append((a, c))
    elif isinstance(v1, Expr):
        self.substitute_rate(v0, v1)
    else:
        for (a, c) in self.terms:
            a.substitute(v0, v1)
    if iutil.check_meta_subs_criteria(v0, v1, self):
        iutil.substitute(self.meta, v0, v1)
    return self
@fcn_substitute
def substitute_whole(self, v0, v1):
    """Substitute v0 by v1 (v1 can be compound), in place; return self.

    Term arguments are wrapped with Expr.fromterm first.
    - If v0 is an Expr: with (coeff, pos) = get_coeff(v0, get_pos = True),
      and when coeff != 0 and pos is not None, the terms of v1 * coeff are
      inserted at pos, the terms of -coeff * v0 are appended, and equal terms
      are merged (combine_same_terms), which cancels the v0 terms.
    - Otherwise, if v1 is an Expr, delegate to substitute_rate.
    - Otherwise, call substitute_whole on each term.

    NOTE(review): unlike substitute(), check_meta_subs_criteria is called
    without self here; confirm whether that difference is intended.
    """
    if isinstance(v0, Term):
        v0 = Expr.fromterm(v0)
    if isinstance(v1, Term):
        v1 = Expr.fromterm(v1)
    self.mhash = None
    if isinstance(v0, Expr):
        if len(v0.terms) > 0:
            coeff, pos = self.get_coeff(v0, get_pos = True)
            if coeff != 0 and pos is not None:
                self.terms[pos:pos] = (v1 * coeff).terms
                self.terms.extend((v0 * (-coeff)).terms)
                self.combine_same_terms()
    elif isinstance(v1, Expr):
        self.substitute_rate(v0, v1)
    else:
        for (a, c) in self.terms:
            a.substitute_whole(v0, v1)
    if iutil.check_meta_subs_criteria(v0, v1):
        iutil.substitute_whole(self.meta, v0, v1)
    return self
def condition(self, b):
    """Condition every IC term additionally on random variable b, in place; return self."""
    for term, _ in self.terms:
        if term.get_type() != TermType.IC:
            continue
        term.z += b
    self.mhash = None
    return self
def conditioned(self, b):
    """Return a copy conditioned on random variable b; self is left unchanged."""
    result = self.copy()
    result.condition(b)
    return result
def var_neighbors(self, v):
    """Return v together with the variables that share a term with v.

    For an IC term involving v, all of its variables (arguments and
    conditioning part) are added; for a region term, the region's own
    var_neighbors(v) is added.
    """
    result = v.copy()
    for term, _ in self.terms:
        kind = term.get_type()
        if kind == TermType.IC:
            allvars = sum(term.x, Comp.empty()) + term.z
            if allvars.ispresent(v):
                result += allvars
        elif kind == TermType.REGION:
            result += term.reg.var_neighbors(v)
    return result
def __abs__(self):
    """Return the absolute value of the expression, as built by eabs."""
    return eabs(self)
def type_only(self, t):
    """Return a new Expr with only the terms of type t (the term objects are shared, not copied)."""
    selected = [pair for pair in self.terms if pair[0].get_type() == t]
    return Expr(selected)
def value(self, method = "", num_iter = 30, prog = None):
    """Evaluate the expression numerically.

    Constant terms contribute their coefficient; other terms contribute
    coefficient times Term.value(...). Returns None as soon as one term has
    no value; returns 0.0 for an empty expression.
    """
    total = 0.0
    for term, coeff in self.terms:
        if term.isone():
            total += coeff
            continue
        val = term.value(method = method, num_iter = num_iter, prog = prog)
        if val is None:
            return None
        total += coeff * val
    return total
def solve_prog(self, method = "", num_iter = 30):
    """Evaluate the expression and return (value, program).

    program is the first program object that value() appended to the list
    it was given, or None if it appended nothing.
    """
    collected = []
    optval = self.value(method = method, num_iter = num_iter, prog = collected)
    return (optval, collected[0] if collected else None)
def __call__(self, method = "", num_iter = 30, prog = None):
    """Shortcut for value(method, num_iter, prog)."""
    return self.value(method = method, num_iter = num_iter, prog = prog)
def __float__(self):
    """Return the numerical value of the expression as a float."""
    val = self.value()
    return float(val)
def __int__(self):
    """Return the numerical value of the expression rounded to the nearest int."""
    rounded = round(float(self.value()))
    return int(rounded)
def split_lhs(self, lhsvar):
    """Partition the terms by whether they involve a variable of lhsvar.

    Returns (lhs, rhs) as Expr objects: lhs holds the terms for which
    recording lhsvar into the term's index adds fewer than all of lhsvar's
    variables, i.e. (assuming IVarIndex ignores already-recorded variables)
    the terms sharing a variable with lhsvar; rhs holds the rest.
    The size of lhsvar's own index is loop-invariant and computed once.
    """
    lhs_index = IVarIndex()
    lhsvar.record_to(lhs_index)
    nlhs = lhs_index.size()
    lhs = []
    rhs = []
    for (a, c) in self.terms:
        index = IVarIndex()
        a.record_to(index)
        t = index.size()
        lhsvar.record_to(index)
        if index.size() != t + nlhs:
            lhs.append((a, c))
        else:
            rhs.append((a, c))
    return (Expr(lhs), Expr(rhs))
def split_posneg(self):
    """Split the terms by coefficient sign.

    Returns (lhs, rhs): rhs holds the terms with positive coefficient, lhs
    all the others (negative, zero or NaN).
    """
    positive = [pair for pair in self.terms if pair[1] > 0]
    rest = [pair for pair in self.terms if not pair[1] > 0]
    return (Expr(rest), Expr(positive))
def split_present(self, eqnstr, lhsvar = None, prefer_ge = None):
    """Split the relation `self eqnstr 0` into (lhs, rhs, eqnstr) for display.

    Parameters:
        eqnstr: relation string such as ">=", "<=", "==".
        lhsvar: variables to put on the left ("real" means all real
            variables of self); None means split by coefficient sign.
        prefer_ge: when splitting by sign, whether to swap the sides
            (defaults to the "str_eqn_prefer_ge" setting).

    With the "str_eps_hide" setting (and eqnstr != "=="), the eps term is
    removed and its sign turns ">=" into ">" or "<=" into "<". The sides
    are finally negated, and the relation reversed, when that makes the
    coefficient signs (or sums) of the left, then right side non-negative.
    """
    if isinstance(lhsvar, str) and lhsvar == "real":
        lhsvar = self.allcomprealvar()
    if prefer_ge is None:
        prefer_ge = PsiOpts.settings["str_eqn_prefer_ge"]
    eps_hide = PsiOpts.settings["str_eps_hide"]
    cs = self
    if eps_hide and eqnstr != "==":
        eps_coeff = cs.get_coeff(Term.eps())
        cs = cs.substituted(Expr.eps(), Expr.zero())
        if eqnstr == ">=" and eps_coeff < 0:
            eqnstr = ">"
        elif eqnstr == "<=" and eps_coeff > 0:
            eqnstr = "<"
    lhs = cs
    rhs = Expr.zero()
    if lhsvar is not None:
        lhs, rhs = cs.split_lhs(lhsvar)
    # Fall back to a split by coefficient sign when one side would be empty.
    if lhs.iszero() or rhs.iszero():
        lhs, rhs = cs.split_posneg()
        if prefer_ge:
            lhs, rhs = rhs, lhs
        # Avoid an empty or constant-only left side.
        if lhs.iszero():
            lhs, rhs = rhs, lhs
        elif lhs.get_const() is not None:
            lhs, rhs = rhs, lhs
    # lhs + rhs eqnstr 0  <=>  lhs eqnstr -rhs
    rhs *= -1.0
    toflip = False
    lhs_sign = lhs.coeff_sign()
    if lhs_sign < 0:
        toflip = True
    elif lhs_sign == 0:
        rhs_sign = rhs.coeff_sign()
        if rhs_sign < 0:
            toflip = True
        elif rhs_sign == 0:
            lhs_sum = lhs.coeff_sum()
            if lhs_sum < 0:
                toflip = True
            elif lhs_sum == 0:
                rhs_sum = rhs.coeff_sum()
                if rhs_sum < 0:
                    toflip = True
    if toflip:
        lhs *= -1.0
        rhs *= -1.0
        eqnstr = iutil.reverse_eqnstr(eqnstr)
    return (lhs, rhs, eqnstr)
def tostring_eqn(self, eqnstr, style = 0, tosort = False, lhsvar = None, prefer_ge = None, line_len = None, pf_note = None):
    """Render the relation `self eqnstr 0` as "lhs eqnstr rhs".

    pf_note, lhsvar and (in LaTeX style) line_len default to the
    "str_proof_note", "lhsvar" and "latex_line_len" settings. If enabled,
    the "pf_note" meta entry is appended as a proof note. In LaTeX style
    the pieces are passed to iutil.latex_split_line; otherwise they are
    joined into one string.
    """
    if pf_note is None:
        pf_note = PsiOpts.settings["str_proof_note"]
    if lhsvar is None:
        lhsvar = PsiOpts.settings["lhsvar"]
    style = iutil.convert_str_style(style)
    is_latex = style & PsiOpts.STR_STYLE_LATEX
    if line_len is None and is_latex:
        line_len = PsiOpts.settings["latex_line_len"]
    lhs, rhs, eqnstr = self.split_present(eqnstr, lhsvar, prefer_ge)
    text = " ".join([
        lhs.tostring(style = style, tosort = tosort, tosort_pm = True),
        iutil.eqnstr_style(eqnstr, style),
        rhs.tostring(style = style, tosort = tosort, tosort_pm = True),
    ])
    note = None
    if pf_note:
        note = self.get_meta("pf_note")
        if note is not None:
            note = iutil.pf_note_str(note, style, add_space = 4)
    if note is not None:
        text = [text, note]
    if is_latex:
        return iutil.latex_split_line(text, line_len)
    if isinstance(text, list):
        text = "".join(text)
    return text
def tostring_line_len(self, eqnstr = "", style = 0, tosort = False, line_len = None):
    """Render the expression, optionally prefixed by a relation symbol.

    In LaTeX style with a known line length (defaulting to the
    "latex_line_len" setting), the result is split into lines with
    iutil.latex_split_line.
    """
    style = iutil.convert_str_style(style)
    is_latex = style & PsiOpts.STR_STYLE_LATEX
    if line_len is None and is_latex:
        line_len = PsiOpts.settings["latex_line_len"]
    prefix = iutil.eqnstr_style(eqnstr, style) + " " if eqnstr != "" else ""
    text = prefix + self.tostring(style = style, tosort = tosort, tosort_pm = True)
    if is_latex and line_len is not None:
        text = iutil.latex_split_line(text, line_len, slstr = "\\;" * 3)
    return text
class FcnRelation(IBaseObj):
    """Stores functional dependencies.

    Each dependency is a pair of bit masks (x, y) meaning "y is a function
    of x", where bit positions are the variable indices of self.index.
    """
    def __init__(self, fcn = None):
        self.index = IVarIndex()
        self.fcn = []
        if fcn is not None:
            self += fcn

    def copy(self):
        """Return a copy with its own index and dependency list."""
        dup = FcnRelation()
        dup.index = self.index.copy()
        dup.fcn = self.fcn[:]
        return dup

    def check_fcn(self, x, y):
        """Return whether mask y is determined by mask x under the stored dependencies.

        x is repeatedly enlarged by every dependency whose source it covers,
        stopping as soon as it covers y or nothing more can be added.
        """
        if y | x == x:
            return True
        grew = True
        while grew:
            grew = False
            for dep in self.fcn:
                src, dst = dep[0], dep[1]
                if src | x == x and dst | x != x:
                    x |= dst
                    if y | x == x:
                        return True
                    grew = True
        return False

    def simplify(self):
        """Drop every dependency that is implied by the remaining ones, in place."""
        pos = 0
        while pos < len(self.fcn):
            dep = self.fcn.pop(pos)
            if self.check_fcn(dep[0], dep[1]):
                continue
            self.fcn.insert(pos, dep)
            pos += 1

    def add_fcn(self, x, y):
        """Add the dependency "y is a function of x" (masks) unless already implied."""
        y &= ~x
        if not self.check_fcn(x, y):
            self.fcn.append((x, y))

    def check(self, x, y):
        """Return whether y is a function of x; Comp arguments are converted to masks."""
        x = self.index.get_mask(x) if isinstance(x, Comp) else x
        y = self.index.get_mask(y) if isinstance(y, Comp) else y
        return self.check_fcn(x, y)

    def __iadd__(self, other):
        """Merge other in place and return self.

        other may be a list (each item merged), a tuple (x, y) of masks or
        Comp meaning "y is a function of x", another FcnRelation (its
        dependencies are merged), or any IBaseObj (its variables are
        recorded in the index).
        """
        if isinstance(other, list):
            for item in other:
                self += item
        elif isinstance(other, tuple):
            pair = list(other)
            for k in (0, 1):
                if isinstance(pair[k], Comp):
                    self.index.record(pair[k])
                    pair[k] = self.index.get_mask(pair[k])
            pair[1] &= ~(pair[0])
            self.fcn.append(tuple(pair))
        elif isinstance(other, FcnRelation):
            for dep in other.fcn:
                self += (other.index.from_mask(dep[0]), other.index.from_mask(dep[1]))
        elif isinstance(other, IBaseObj):
            other.record_to(self.index)
        return self

    def __add__(self, other):
        merged = self.copy()
        merged += other
        return merged

    def get_hc(self):
        """Return the sum of H(y | x) over all dependencies (x, y)."""
        total = Expr.zero()
        for dep in self.fcn:
            total += Expr.Hc(self.index.from_mask(dep[1]), self.index.from_mask(dep[0]))
        return total

    def get_region(self):
        """Return the region stating that every stored dependency holds."""
        return self.get_hc() <= 0

    def __bool__(self):
        return bool(self.get_region())
class PartialOrder(IBaseObj):
    """Partial ordering.

    Entries are kept in self.terms, and self.table[i][j] records whether
    terms[i] <= terms[j] (the diagonal is True). An entry may be a pair
    (display, value): value is what gets compared and turned into regions,
    display is what gets printed.
    """
    def __init__(self, comparator = None, terms = None):
        self.comparator = comparator
        self.terms = []
        self.table = []
        if terms is not None:
            for x in terms:
                self.add(x)

    def copy(self):
        r = PartialOrder()
        r.comparator = iutil.copy(self.comparator)
        r.terms = iutil.copy(self.terms)
        r.table = [list(x) for x in self.table]
        return r

    def _term_value(self, i):
        """Return the value of entry i used for comparisons (second item of a pair)."""
        x = self.terms[i]
        if isinstance(x, tuple):
            x = x[1]
        return x

    def get_ancestors_id(self, x, descendant = False, invert = False):
        """Yield the indices i with terms[x] <= terms[i] (or terms[i] <= terms[x] if descendant).

        With invert, yield the indices for which the relation does not hold.
        """
        n = len(self.terms)
        if descendant:
            for i in range(n):
                if invert ^ self.table[i][x]:
                    yield i
        else:
            for i in range(n):
                if invert ^ self.table[x][i]:
                    yield i

    def compare(self, x, y):
        """Return whether x <= y, using the comparator (None, a Region of assumptions, or a callable)."""
        if isinstance(x, tuple):
            x = x[1]
        if isinstance(y, tuple):
            y = y[1]
        if self.comparator is None:
            return x <= y
        if isinstance(self.comparator, Region):
            return (self.comparator >> (x <= y))
        return self.comparator(x, y)

    def get_leader_id(self):
        """For each entry, return -1 if it leads its equivalence class, else the leader's index.

        The leader is the class member of lowest complexity (first one on ties).
        """
        n = len(self.terms)
        vis = [False] * n
        r = [-1] * n
        for i in range(n):
            if vis[i]:
                continue
            minc = None
            minj = -1
            for j in range(i, n):
                if self.table[i][j] and self.table[j][i]:
                    vis[j] = True
                    c = 0
                    t = self._term_value(j)
                    if isinstance(t, (Expr, Region)):
                        c = t.complexity()
                    if minc is None or c < minc:
                        minc = c
                        minj = j
            for j in range(i, n):
                if self.table[i][j] and self.table[j][i]:
                    if j == minj:
                        r[j] = -1
                    else:
                        r[j] = minj
        return r

    def unique(self):
        """Unique entries in this partial ordering.
        """
        n = len(self.terms)
        leaders = self.get_leader_id()
        return [self.terms[i] for i in range(n) if leaders[i] < 0]

    def minimal(self, unique = True, flipped = False):
        """Minimal entries in this partial ordering (maximal ones if flipped).
        """
        n = len(self.terms)
        leaders = None
        if unique:
            leaders = self.get_leader_id()
        r = []
        for i in range(n):
            if unique and leaders[i] >= 0:
                continue
            if flipped:
                if any(self.table[i][j] and not self.table[j][i] for j in range(n)):
                    continue
            else:
                if any(self.table[j][i] and not self.table[i][j] for j in range(n)):
                    continue
            r.append(self.terms[i])
        return r

    def maximal(self, unique = True, flipped = False):
        """Maximal entries in this partial ordering.
        """
        return self.minimal(unique, not flipped)

    def add(self, x):
        """Insert x, determining its relations with as few comparisons as possible.

        pos[i] = [terms[i] <= x, x <= terms[i]], with None for unknown. Each
        round picks the unknown comparison (against a class leader) whose
        worst-case outcome leaves the fewest unknowns, performs it, and
        propagates the result by transitivity.
        """
        verbose = PsiOpts.settings.get("verbose_partialorder", False)
        n = len(self.terms)
        pos = [[None, None] for i in range(n)]
        leaders = self.get_leader_id()
        def add_compare(i, sgn, val):
            # Return pos updated with the outcome `val` of comparison (i, sgn).
            r = [list(t) for t in pos]
            if val:
                for j in self.get_ancestors_id(i, sgn == 0):
                    r[j][sgn] = True
                for j in self.get_ancestors_id(i, sgn == 1, invert = True):
                    r[j][1 - sgn] = False
            else:
                for j in self.get_ancestors_id(i, sgn == 1):
                    r[j][sgn] = False
            return r
        def num_uncertain(cpos):
            return sum(int(cpos[i][s] is None) for i in range(n) for s in range(2) if leaders[i] < 0)
        if verbose:
            print(", ".join(str(t) for t in self.terms))
            print("ADD " + str(x))
            print()
        while True:
            min_i = None
            min_s = None
            min_nu = None
            min_pos0 = None
            min_pos1 = None
            for i in range(n):
                if leaders[i] >= 0:
                    continue
                for s in range(2):
                    if pos[i][s] is not None:
                        continue
                    tpos1 = add_compare(i, s, True)
                    tnu1 = num_uncertain(tpos1)
                    tpos0 = add_compare(i, s, False)
                    tnu0 = num_uncertain(tpos0)
                    if min_i is None or max(tnu0, tnu1) < min_nu:
                        min_i = i
                        min_s = s
                        min_nu = max(tnu0, tnu1)
                        min_pos0 = tpos0
                        min_pos1 = tpos1
            if min_i is None:
                break
            res = None
            if min_s == 0:
                if verbose:
                    print(str(self.terms[min_i]) + " <= " + str(x) + " ?")
                res = self.compare(self.terms[min_i], x)
            else:
                if verbose:
                    print(str(x) + " <= " + str(self.terms[min_i]) + " ?")
                res = self.compare(x, self.terms[min_i])
            if res:
                pos = [list(t) for t in min_pos1]
            else:
                pos = [list(t) for t in min_pos0]
            if verbose:
                for i in range(n + 1):
                    st = ""
                    for j in range(n + 1):
                        c = None
                        if i == n:
                            if j == n:
                                c = True
                            else:
                                c = pos[j][1]
                        else:
                            if j < n:
                                c = self.table[i][j]
                            else:
                                c = pos[i][0]
                        st += " " +("?" if c is None else "1" if c else "0")
                    print(st)
                print()
        self.terms.append(x)
        for i in range(n):
            self.table[i].append(pos[i][0])
        self.table.append([pos[i][1] for i in range(n)] + [True])

    def hasse(self):
        """Return (eq, ch) describing the Hasse diagram.

        eq[i] is the smallest index equivalent to i if that is not i itself,
        else -1. For each class representative i, ch[i] lists the
        representatives j covering i (terms[i] <= terms[j] with nothing in
        between).
        """
        n = len(self.terms)
        eq = [-1] * n
        for i in range(n):
            for j in range(i):
                if self.table[i][j] and self.table[j][i]:
                    eq[i] = j
                    break
        ch = [[] for i in range(n)]
        for i in range(n):
            if eq[i] >= 0:
                continue
            for j in range(n):
                if i == j or eq[j] >= 0 or not self.table[i][j]:
                    continue
                for k in range(n):
                    if k == i or k == j or eq[k] >= 0:
                        continue
                    if self.table[i][k] and self.table[k][j]:
                        break
                else:
                    ch[i].append(j)
        return (eq, ch)

    def get_region(self):
        """Return the Region asserting the equalities and covering relations of the Hasse diagram.

        The compared values are used, not the (display, value) pairs
        themselves; comparing pairs would yield plain bools, not regions.
        """
        eq, ch = self.hasse()
        n = len(self.terms)
        r = Region.universe()
        for i in range(n):
            xi = self._term_value(i)
            if eq[i] >= 0:
                r &= self._term_value(eq[i]) == xi
            for j in ch[i]:
                r &= xi <= self._term_value(j)
        return r

    def _latex_(self):
        r = ""
        with PsiOpts(repr_simplify = False):
            r = self.get_region()._latex_()
        return r

    def tostring_term(self, i, style=None):
        """Return the display string of entry i (the first item of a pair)."""
        x = self.terms[i]
        if isinstance(x, tuple):
            x = x[0]
        if style is not None and style == "graphviz":
            return x.graphviz_text()
        return str(x)

    def tostring(self, style=None):
        """One line per entry: its equal representative, or the entries covering it."""
        eq, ch = self.hasse()
        n = len(self.terms)
        r = ""
        for i in range(n):
            r += self.tostring_term(i, style=style)
            if eq[i] >= 0:
                r += " = " + self.tostring_term(eq[i], style=style)
            elif ch[i]:
                r += " <= " + ", ".join(self.tostring_term(j, style=style) for j in ch[i])
            r += "\n"
        return r

    def __str__(self):
        return self.tostring()

    def __repr__(self):
        return self.tostring()

    def _repr_svg_(self):
        if graphviz is None:
            return None
        try:
            return self.graph()._repr_svg_()
        except Exception:
            return self.graph().build_dot()._repr_svg_()

    def graph(self, shape = "plaintext", lr = False, ortho = False, arrowhead = "none", leaders_only = False, reverse = False, **kwargs):
        """Return the graphviz digraph of the Hasse diagram.

        One node is drawn per equivalence class, labelled by its members
        joined with " = " (only class leaders if leaders_only). One edge is
        drawn per covering relation, pointing from the larger to the smaller
        entry unless reverse is True. Extra keyword arguments become graph
        attributes.
        """
        if graphviz is None:
            raise RuntimeError("Requires graphviz. Please install it first.")
        eq, ch = self.hasse()
        n = len(self.terms)
        leaders = [-1] * n
        if leaders_only:
            leaders = self.get_leader_id()
        r = graphviz.Digraph()
        if lr:
            r.graph_attr["rankdir"] = "LR"
        if ortho:
            r.graph_attr["splines"] = "ortho"
        for key, value in kwargs.items():
            r.graph_attr[key] = str(value)
        # Index-based node ids: labels may coincide, and graphviz parses ":"
        # in an edge's node id as a port separator.
        node_id = ["n" + str(i) for i in range(n)]
        for i in range(n):
            if eq[i] >= 0:
                # Non-representatives are shown inside their class's node.
                continue
            members = [i] + [k for k in range(n) if eq[k] == i]
            label = " = ".join(self.tostring_term(j, style="graphviz")
                               for j in members if leaders[j] < 0)
            r.node(node_id[i], label, shape = shape)
        for i in range(n):
            for j in ch[i]:
                if reverse:
                    r.edge(node_id[i], node_id[j], arrowhead=arrowhead)
                else:
                    r.edge(node_id[j], node_id[i], arrowhead=arrowhead)
        return r
class BayesNet(IBaseObj):
"""Bayesian network"""
def __init__(self, edges = None):
    """Create a Bayesian network, optionally merging `edges` (anything accepted by +=)."""
    self.index = IVarIndex()
    self.parent, self.child, self.fcn = [], [], []
    if edges is not None:
        self += edges
def copy(self):
    """Return an independent copy (index, adjacency lists and fcn flags are duplicated)."""
    dup = BayesNet()
    dup.index = self.index.copy()
    dup.parent = [lst[:] for lst in self.parent]
    dup.child = [lst[:] for lst in self.child]
    dup.fcn = self.fcn[:]
    return dup
def allcomp(self):
    """Return a copy of the Comp of all variables (nodes) of the network."""
    nodes = self.index.comprv
    return nodes.copy()
def get_parents(self, x):
    """
    Get the parents of node x.

    Parameters
    ----------
    x : Comp

    Returns
    -------
    Comp
        The parents of x, or None if x is not a node of the network.
    """
    node = self.index.get_index(x)
    if node < 0:
        return None
    return sum((self.index.comprv[p] for p in self.parent[node]), Comp.empty())
def get_children(self, x):
    """
    Get the children of node x.

    Parameters
    ----------
    x : Comp

    Returns
    -------
    Comp
        The children of x, or None if x is not a node of the network.
    """
    node = self.index.get_index(x)
    if node < 0:
        return None
    return sum((self.index.comprv[c] for c in self.child[node]), Comp.empty())
def get_ancestors(self, x, descendant = False, include_self = True):
    """
    Get the ancestors of node x.

    Parameters
    ----------
    x : Comp
        A node of the network.
    descendant : bool
        If True, follow child links instead, i.e. collect descendants.
    include_self : bool
        Whether x itself is included. It is also included when it lies on
        a directed cycle through itself.

    Returns
    -------
    Comp
        Empty if x is not a node of the network.
    """
    i = self.index.get_index(x)
    if i < 0:
        # get_index signals an unknown node with a negative index; using it
        # below would index vis[-1] and search from the last node instead.
        return Comp.empty()
    n = self.index.comprv.size()
    vis = [False] * n
    vis[i] = include_self
    cstack = [i]
    r = Comp.empty()
    while len(cstack):
        x = cstack.pop()
        if vis[x]:
            r += self.index.comprv[x]
        for y in (self.child[x] if descendant else self.parent[x]):
            if not vis[y]:
                vis[y] = True
                cstack.append(y)
    return r
def get_descendants(self, x, **kwargs):
    """
    Get the descendants of node x.

    Parameters
    ----------
    x : Comp
    **kwargs
        Forwarded to get_ancestors (e.g. include_self).

    Returns
    -------
    Comp
    """
    return self.get_ancestors(x, descendant = True, **kwargs)
def edges(self):
    """
    Generator over the edges of the network.

    Yields
    ------
    Pairs (parent, child) of Comp, grouped by parent in index order.
    """
    count = self.index.comprv.size()
    for src in range(count):
        for dst in self.child[src]:
            yield (self.index.comprv[src], self.index.comprv[dst])
def add_edge_id(self, i, j):
    """Add the edge i -> j by index; self-loops and negative indices are ignored, duplicates are not added."""
    if min(i, j) < 0 or i == j:
        return
    if i not in self.parent[j]:
        self.parent[j].append(i)
    if j not in self.child[i]:
        self.child[i].append(j)
def remove_edge_id(self, i, j):
    """Remove the edge i -> j by index (ValueError if absent); self-loops and negative indices are ignored."""
    if min(i, j) < 0 or i == j:
        return
    self.parent[j].remove(i)
    self.child[i].remove(j)
def record(self, x):
    """Record the variables of x in the index and grow parent/child/fcn to the new size."""
    self.index.record(x)
    size = self.index.comprv.size()
    self.parent.extend([] for _ in range(size - len(self.parent)))
    self.child.extend([] for _ in range(size - len(self.child)))
    self.fcn.extend([False] * (size - len(self.fcn)))
@fcn_list_to_list
def set_fcn(self, x, v = True):
    """Mark variables in x to be functions of their parents (flag set to v)."""
    self.record(x)
    for var in x.varlist:
        idx = self.index.get_index(var)
        if idx < 0:
            continue
        self.fcn[idx] = v
def is_fcn(self, x):
    """Query whether every known variable in x is a function of its parents (unknown ones are ignored)."""
    indices = (self.index.get_index(var) for var in x.varlist)
    return all(self.fcn[i] for i in indices if i >= 0)
@fcn_list_to_list
def add_edge(self, x, y):
    """Add edges from every variable in x to every variable in y.
    Also add edges among variables in y, from each one to every later one.
    Returns y.
    """
    self.record(x)
    self.record(y)
    for xa in x.varlist:
        for ya in y.varlist:
            self.add_edge_id(self.index.get_index(xa), self.index.get_index(ya))
    ys = y.varlist
    for a in range(len(ys)):
        for b in range(a + 1, len(ys)):
            self.add_edge_id(self.index.get_index(ys[a]), self.index.get_index(ys[b]))
    return y
def __iadd__(self, other):
    """Merge other into this network in place and return self.

    other may be a list (each item merged in turn); a tuple (a chain, with
    add_edge between consecutive entries); a Term (add_edge from its
    conditioning part z to its first argument x[0]); a Comp (add_edge from
    the empty Comp, i.e. edges among its own variables); or another BayesNet
    (its edges and fcn markers are copied). Anything else is ignored.
    """
    if isinstance(other, list):
        for item in other:
            self += item
    elif isinstance(other, tuple):
        for src, dst in zip(other[:-1], other[1:]):
            self.add_edge(src, dst)
    elif isinstance(other, Term):
        self.add_edge(other.z, other.x[0])
    elif isinstance(other, Comp):
        self.add_edge(Comp.empty(), other)
    elif isinstance(other, BayesNet):
        for src, dst in other.edges():
            self.add_edge(src, dst)
        for var in other.allcomp():
            if other.is_fcn(var):
                self.set_fcn(var)
    return self
def __add__(self, other):
    """Return a copy of self with other merged into it (see __iadd__)."""
    merged = self.copy()
    merged += other
    return merged
def join(self, other):
    """Serially compose this network with other.

    A non-BayesNet other is wrapped as BayesNet([other.allcomp()]). The
    result is self + other, plus an edge from every variable of self with
    no children to every variable of other with no parents.
    """
    if not isinstance(other, BayesNet):
        other = BayesNet([other.allcomp()])
    result = self + other
    sinks = [x for x in self.allcomp() if self.get_children(x).isempty()]
    sources = [y for y in other.allcomp() if other.get_parents(y).isempty()]
    for x in sinks:
        for y in sources:
            result += (x, y)
    return result
def __floordiv__(self, other):
    """`a // b` is shorthand for a.join(b)."""
    joined = self.join(other)
    return joined
def __xor__(self, other):
    """`a ^ b` is shorthand for a + b (union of the two networks)."""
    union = self + other
    return union
def communicate(self, a, b):
    """True if b is in get_ancestors(a) and a is in get_ancestors(b)."""
    first = self.get_ancestors(a).ispresent(b)
    if not first:
        return first
    return self.get_ancestors(b).ispresent(a)
def scc(self):
    """Condense groups of mutually communicating variables into chains.

    Variables that communicate (see communicate()) are gathered into one
    group. The k-th variable of a group is given as parents the group's
    external parents plus the first k variables of the group. It keeps its
    function flag only if that new parent set covers its original parents.
    Returns a new BayesNet.

    NOTE(review): a variable is grouped with itself only if
    communicate(x, x) holds, i.e. if get_ancestors(x) contains x. If
    get_ancestors is not reflexive, variables on no cycle would be
    dropped. Confirm.
    """
    n = self.index.comprv.size()
    r = BayesNet()
    vis = [False] * n
    for i in range(n):
        if vis[i]:
            continue
        cgroup = Comp.empty()
        cparent = Comp.empty()
        # Only j >= i is scanned: a communicating j < i would have claimed i
        # already (vis[i] set) when it was processed.
        for j in range(i, n):
            if self.communicate(self.index.comprv[i], self.index.comprv[j]):
                vis[j] = True
                cgroup += self.index.comprv[j]
                cparent += self.get_parents(self.index.comprv[j])
        # Parents inside the group are replaced by the chain ordering below.
        cparent -= cgroup
        for k in range(len(cgroup)):
            tparent = cparent + cgroup[:k]
            r.add_edge(tparent, cgroup[k])
            if self.is_fcn(cgroup[k]) and tparent.super_of(self.get_parents(cgroup[k])):
                r.set_fcn(cgroup[k])
    return r
def get_components(self):
    """Return the weakly connected components as a list of Comp.

    Edge directions are ignored. Components appear in order of their
    smallest node index, and each one is built starting from that node.
    """
    size = self.index.comprv.size()
    seen = [False] * size
    components = []
    for start in range(size):
        if seen[start]:
            continue
        seen[start] = True
        components.append(self.index.comprv[start])
        pending = [start]
        while pending:
            node = pending.pop()
            for nb in self.child[node] + self.parent[node]:
                if seen[nb]:
                    continue
                seen[nb] = True
                pending.append(nb)
                components[-1] += self.index.comprv[nb]
    return components
def indep_components(self):
    """Group nodes into components grown from parentless roots.

    Every parentless node starts a new group. Any other node joins the
    first group that already contains all of its parents. Passes repeat
    until nothing changes; nodes that never fit are left out. Returns a
    list of Comp (via index.from_mask).
    """
    size = self.index.comprv.size()
    placed = [False] * size
    masks = []
    progress = True
    while progress:
        progress = False
        for node in range(size):
            if placed[node]:
                continue
            parents = self.parent[node]
            if not parents:
                masks.append(1 << node)
                placed[node] = True
                progress = True
                continue
            for g, m in enumerate(masks):
                if all(m >> p & 1 for p in parents):
                    masks[g] = m | (1 << node)
                    placed[node] = True
                    progress = True
                    break
    return [self.index.from_mask(m) for m in masks]
def tsorted(self):
    """Return a topologically sorted copy of the network, or None if cyclic.

    Kahn's algorithm with a min-heap: among the nodes whose parents are
    all emitted, the smallest index goes first. Edges and function flags
    are copied into the new network.
    """
    comprv = self.index.comprv
    size = comprv.size()
    remaining = [len(self.parent[i]) for i in range(size)]
    ready = [i for i in range(size) if remaining[i] == 0]
    heapq.heapify(ready)
    result = BayesNet()
    emitted = 0
    while ready:
        node = heapq.heappop(ready)
        result.record(comprv[node])
        emitted += 1
        for ch in self.child[node]:
            if remaining[ch] > 0:
                remaining[ch] -= 1
                if remaining[ch] == 0:
                    heapq.heappush(ready, ch)
    if emitted < size:
        return None
    for i in range(size):
        for p in self.parent[i]:
            result.add_edge(comprv[p], comprv[i])
    for i in range(size):
        if self.fcn[i]:
            result.set_fcn(comprv[i])
    return result
def tsorted_stack(self):
    """Like tsorted, but ready nodes are kept on a LIFO stack.

    Parentless nodes are seeded so that the lowest index is popped first,
    and children are pushed in reverse order. Returns None if the network
    is cyclic.
    """
    comprv = self.index.comprv
    size = comprv.size()
    remaining = [len(self.parent[i]) for i in range(size)]
    stack = [i for i in reversed(range(size)) if remaining[i] == 0]
    result = BayesNet()
    emitted = 0
    while stack:
        node = stack.pop()
        result.record(comprv[node])
        emitted += 1
        for ch in reversed(self.child[node]):
            if remaining[ch] > 0:
                remaining[ch] -= 1
                if remaining[ch] == 0:
                    stack.append(ch)
    if emitted < size:
        return None
    for i in range(size):
        for p in self.parent[i]:
            result.add_edge(comprv[p], comprv[i])
    for i in range(size):
        if self.fcn[i]:
            result.set_fcn(comprv[i])
    return result
def reversed(self):
    """Return a copy whose variables are registered in reverse index order.

    Edges and function flags are preserved; only the internal variable
    ordering differs.
    """
    comprv = self.index.comprv
    size = comprv.size()
    result = BayesNet()
    for var in reversed(comprv):
        result += var
    for child in range(size):
        for par in self.parent[child]:
            result.add_edge(comprv[par], comprv[child])
    for i in (k for k in range(size) if self.fcn[k]):
        result.set_fcn(comprv[i])
    return result
def iscyclic(self):
    """Return True if the network has a directed cycle (tsorted fails)."""
    order = self.tsorted()
    return order is None
def contracted_node(self, x):
    """Return a new network with variable x contracted out.

    When the network is acyclic, the work is done on its topologically
    sorted copy. The children of x are connected among themselves (lower
    index -> higher index), and every edge out of x is replaced by edges
    from the parents of x. A functional variable stays functional only if
    x was functional or x was not one of its parents. If x is not in the
    network, a plain copy is returned.

    NOTE(review): r only records variables passed to add_edge/set_fcn, so
    a non-functional variable left with no edges is absent from the
    result. Confirm this is intended.
    """
    self_tsorted = self.tsorted()
    if self_tsorted is not None:
        self = self_tsorted
    # self = self.tsorted()
    k = self.index.get_index(x)
    if k < 0:
        return self.copy()
    n = self.index.comprv.size()
    r = BayesNet()
    child = sorted(self.child[k])
    # Connect the children of x to each other in index order.
    for i0 in range(len(child)):
        i = child[i0]
        for j0 in range(i0 + 1, len(child)):
            j = child[j0]
            r.add_edge(self.index.comprv[i], self.index.comprv[j])
    # Copy all other edges; an edge x -> i becomes parent(x) -> i.
    for i in range(n):
        if i == k:
            continue
        for j in self.parent[i]:
            if j == k:
                for j2 in self.parent[k]:
                    if j2 == k:
                        continue
                    r.add_edge(self.index.comprv[j2], self.index.comprv[i])
                continue
            r.add_edge(self.index.comprv[j], self.index.comprv[i])
    # Function flags survive unless they depended on a non-functional x.
    for i in range(n):
        if i == k:
            continue
        if self.fcn[i]:
            if self.fcn[k] or (k not in self.parent[i]):
                r.set_fcn(self.index.comprv[i])
    return r
def eliminated(self, x):
    """Return a copy with every variable of x contracted out, one at a time."""
    current = self.copy()
    for var in x:
        current = current.contracted_node(var)
    return current
def __sub__(self, other):
    """`bnet - x` contracts out the variables of x (see eliminated)."""
    reduced = self.eliminated(other)
    return reduced
def check_hc_mask(self, x, z):
    """Graphically check that the variables in bitmask x are functions of z.

    Walks backwards from the nodes of x outside z without entering z.
    Every node reached must be flagged as a function of its parents.
    Otherwise returns False.
    """
    size = self.index.comprv.size()
    x &= ~z
    visited = [bool((x | z) >> i & 1) for i in range(size)]
    todo = [i for i in range(size) if x >> i & 1]
    while todo:
        node = todo.pop()
        if not self.fcn[node]:
            return False
        for p in self.parent[node]:
            if visited[p]:
                continue
            visited[p] = True
            todo.append(p)
    return True
def fcn_descendants_mask(self, x):
    """Close the bitmask x under functional dependence.

    Repeatedly adds every node that is flagged as a function of its
    parents (self.fcn) and whose parents all lie in x, until nothing
    changes. A functional node with no parents is always added.

    Parameters
    ----------
    x : int
        Bitmask of node indices.

    Returns
    -------
    int
        The enlarged bitmask.
    """
    n = self.index.comprv.size()
    did = True
    while did:
        did = False
        for i in range(n):
            if self.fcn[i] and not x & (1 << i):
                # j is a node index, not a mask: test bit j of x.
                if all(x & (1 << j) for j in self.parent[i]):
                    x |= 1 << i
                    did = True
    return x
def fcn_descendants(self, x):
    """Closure of x under functional dependence, returned as a Comp."""
    mask = self.index.get_mask(x)
    closed = self.fcn_descendants_mask(mask)
    return self.index.from_mask(closed)
def check_ic_mask(self, x, y, z, icset = None):
    """Graphical conditional-independence test on variable bitmasks.

    First z is enlarged by its functional closure (fcn_descendants_mask)
    and removed from x and y.

    Without icset: returns True if no active path from x reaches y given z
    (a d-separation-style reachability search over (direction, node)
    states). Returns False if x and y still overlap.

    With icset (a one-element list used as an output slot): y is ignored.
    icset[0] receives the bitmask of variables outside x, z and the
    nodes reached from x, and True is returned.

    Negative masks return False. Presumably get_mask signals an unknown
    variable with a negative value; TODO confirm.
    """
    if x < 0 or (icset is None and y < 0) or z < 0:
        return False
    z = self.fcn_descendants_mask(z)
    x &= ~z
    y &= ~z
    if icset is None:
        if x & y != 0:
            # if not self.check_hc_mask(x & y, z):
            # return False
            # z |= x & y
            # x &= ~z
            # y &= ~z
            return False
        if x == 0 or y == 0:
            return True
    else:
        if x == 0:
            icset[0] = (1 << self.index.comprv.size()) - 1 - z
            return True
    n = self.index.comprv.size()
    # desc: z plus all ancestors of z, i.e. nodes that are in z or have a
    # descendant in z (a collider at such a node lets the path through).
    desc = z
    cstack = []
    for i in range(n):
        if z & (1 << i) != 0:
            cstack.append(i)
    while len(cstack) > 0:
        i = cstack.pop()
        for j in self.parent[i]:
            if desc & (1 << j) == 0:
                desc |= (1 << j)
                cstack.append(j)
    # Search states (d, i): d == 1 means node i was entered from a child
    # (moving up), d == 0 means it was entered from a parent (moving down).
    # vis[d] is the bitmask of nodes already queued with direction d.
    vis = [0, x]
    cstack = []
    for i in range(n):
        if x & (1 << i) != 0:
            cstack.append((1, i))
    cicset = 0
    while len(cstack) > 0:
        (d, i) = cstack.pop()
        if icset is None:
            if y & (1 << i) != 0:
                # An active path from x reached y.
                return False
        else:
            cicset |= 1 << i
        # A node outside z passes the path on to its children.
        if z & (1 << i) == 0:
            for j in self.child[i]:
                if vis[0] & (1 << j) == 0:
                    vis[0] |= (1 << j)
                    cstack.append((0, j))
        # Going up: from a node entered from below that is outside z, or
        # from a node entered from above that is in desc (collider case).
        if (d == 0 and desc & (1 << i) != 0) or (d == 1 and z & (1 << i) == 0):
            for j in self.parent[i]:
                if vis[1] & (1 << j) == 0:
                    vis[1] |= (1 << j)
                    cstack.append((1, j))
    if icset is not None:
        icset[0] = (1 << n) - 1 - (x | z | cicset)
    return True
def check_ic(self, icexpr):
    """Return True if every term of icexpr is implied by the network.

    Terms with isihc2() are checked by check_ic_mask (on x[0], x[1], z).
    Terms with ishc() are checked by check_hc_mask (on x[0], z). Any
    other term makes the check fail.
    """
    mask = self.index.get_mask
    for term, _coeff in icexpr.terms:
        if term.isihc2():
            ok = self.check_ic_mask(mask(term.x[0]), mask(term.x[1]), mask(term.z))
        elif term.ishc():
            ok = self.check_hc_mask(mask(term.x[0]), mask(term.z))
        else:
            ok = False
        if not ok:
            return False
    return True
def max_ic_set(self, x, z):
    """Variables reported by check_ic_mask (icset mode) as independent of x given z."""
    holder = [0]
    index = self.index
    self.check_ic_mask(index.get_mask(x), 0, index.get_mask(z), holder)
    return index.from_mask(holder[0])
def relative_children_mask(self, x, y, children = True):
    """Nodes of y first reached when walking out from x.

    Walks from the nodes of bitmask x along child edges (parent edges when
    children is False), stopping at nodes in y. Returns the bitmask of the
    y-nodes reached, including nodes of x that are themselves in y.
    """
    size = self.index.comprv.size()
    stack = [i for i in range(size) if x >> i & 1]
    seen = 0
    for i in stack:
        seen |= 1 << i
    found = 0
    adjacency = self.child if children else self.parent
    while stack:
        node = stack.pop()
        bit = 1 << node
        if y & bit:
            found |= bit
            continue
        for nb in adjacency[node]:
            if seen >> nb & 1:
                continue
            seen |= 1 << nb
            stack.append(nb)
    return found
def markov_blanket_mask(self, x, y):
    """Find a subset of y that separates x from the rest of y.

    The candidate consists of the y-nodes first reached from x along child
    edges, plus the y-nodes reached backwards from those and from x. If
    check_ic_mask confirms x is independent of the rest of y given the
    candidate, the candidate is returned; otherwise y (minus x).
    """
    y &= ~x
    reach = x | self.relative_children_mask(x, y)
    back = self.relative_children_mask(reach, y & ~reach, children=False)
    blanket = (reach | back) & ~x
    if not self.check_ic_mask(x, y & ~blanket, blanket):
        return y
    return blanket
def markov_blanket(self, x, y = None):
    """Markov blanket of x within y (default: all other variables), as a Comp."""
    idx = self.index
    candidates = idx.comprv - x if y is None else y
    blanket = self.markov_blanket_mask(idx.get_mask(x), idx.get_mask(candidates))
    return idx.from_mask(blanket)
def from_ic_inplace(self, icexpr, roots = None, add_var = None, add_hc = True):
    """Add edges so that this network implies the terms of icexpr.

    Collects the terms with isic2() (as bitmask triples (x0, x1, z)) and,
    if add_hc, the terms with ishc(). An H(X|Z) term is treated as X
    independent of every other variable given Z.

    A dynamic program over subsets of the non-root variables then picks an
    ordering and, for each variable, a parent set. The parent set is
    either all variables placed before it, or a smaller set zk justified
    by one collected term. The minimized cost is the number of parents
    times (10000 - numcond[i]). The state space has 2**(non-root count)
    entries, so this is exponential.

    roots: recorded first; they always count as already placed and get no
        parents from this method.
    add_var: extra variables to record; they are skipped when counting
        numcond (how many collected terms condition on a variable).
    """
    n_root = 0
    if roots is not None:
        self.record(roots)
        n_root = self.index.comprv.size()
    if add_var is not None:
        self.record(add_var)
    ics = []
    for (a, c) in icexpr.terms:
        if a.isic2():
            self.record(a.x[0])
            self.record(a.x[1])
            self.record(a.z)
            x0 = self.index.get_mask(a.x[0])
            x1 = self.index.get_mask(a.x[1])
            z = self.index.get_mask(a.z)
            # Make x0, x1, z pairwise disjoint.
            x0 &= ~z
            x1 &= ~z
            x0 &= ~x1
            ics.append((x0, x1, z))
        elif add_hc and a.ishc():
            self.record(a.x[0])
            self.record(a.z)
            x0 = self.index.get_mask(a.x[0])
            z = self.index.get_mask(a.z)
            x0 &= ~z
            # -1 marks "independent of everything else"; expanded below.
            ics.append((x0, -1, z))
    n = self.index.comprv.size()
    ics = [(x0, x1 if x1 >= 0 else (1 << n) - 1 - x0 - z, z) for x0, x1, z in ics]
    numcond = [0] * n
    for i in range(n):
        if add_var is not None and add_var.ispresent(self.index.comprv[i]):
            continue
        for x0, x1, z in ics:
            if z & (1 << i):
                numcond[i] += 1
    ilist = list(range(n_root, n))
    ilist.sort(key = lambda i: numcond[i])
    n2 = n - n_root
    xk = 0
    zk = 0
    vis = 0
    np2 = (1 << n2)
    # dp[tvis]: minimal cost to place the remaining variables when the
    # non-root variables in bitmask tvis are already placed.
    # dpi[tvis]: next variable chosen; dped[tvis]: its parent bitmask.
    dp = [100000000000] * np2
    dpi = [-1] * np2
    dped = [-1] * np2
    dp[np2 - 1] = 0
    for tvis in range(np2 - 2, -1, -1):
        # Full bitmask of placed variables, roots always included.
        vis = (tvis << n_root) | ((1 << n_root) - 1)
        nvis = iutil.bitcount(vis)
        for i in ilist:
            if vis & (1 << i) == 0:
                # Option 1: i takes every placed variable as a parent.
                nedge = nvis * (10000 - numcond[i]) + dp[tvis | (1 << (i - n_root))]
                # nedge = nvis + dp[tvis | (1 << (i - n_root))]
                if nedge < dp[tvis]:
                    dp[tvis] = nedge
                    dpi[tvis] = i
                    dped[tvis] = vis
                # Option 2: a term I(x0; x1 | z) with z already placed and i
                # on one side lets i take only zk as parents, provided every
                # placed variable lies in zk or on the other side xk.
                for (x0, x1, z) in ics:
                    if z & ~vis != 0:
                        continue
                    if x0 & (1 << i) != 0:
                        xk = x1
                        zk = (z + x0 - (1 << i)) & vis
                    elif x1 & (1 << i) != 0:
                        xk = x0
                        zk = (z + x1 - (1 << i)) & vis
                    else:
                        continue
                    if vis & ~(zk | xk) != 0:
                        continue
                    nedge = iutil.bitcount(zk) * (10000 - numcond[i]) + dp[tvis | (1 << (i - n_root))]
                    # nedge = iutil.bitcount(zk) + dp[tvis | (1 << (i - n_root))]
                    if nedge < dp[tvis]:
                        dp[tvis] = nedge
                        dpi[tvis] = i
                        dped[tvis] = zk
    #for vis in range(np2):
    #    print("{0:b}".format(vis) + " " + str(dp[vis]) + " " + str(dpi[vis]) + " " + "{0:b}".format(dped[vis]))
    # Replay the optimal choices from the empty placement, adding edges.
    cvis = 0
    for it in range(n2):
        i = dpi[cvis]
        ed = dped[cvis]
        for j in range(n):
            if ed & (1 << j) != 0:
                self.add_edge_id(j, i)
        cvis |= (1 << (i - n_root))
def from_ic(icexpr, roots = None, add_var = None, add_hc = True):
    """Construct Bayesian network from the sum of conditional mutual
    information terms (Expr).

    Builds an empty BayesNet and fills it with from_ic_inplace.
    """
    net = BayesNet()
    net.from_ic_inplace(icexpr, roots, add_var=add_var, add_hc=add_hc)
    return net
def from_ic_list(icexpr, roots = None, add_var = None):
    """Construct a list of Bayesian networks from the sum of conditional
    mutual information terms (Expr).

    Repeatedly builds one network with from_ic and drops every term it
    implies (check_ic). If a network implies none of the remaining
    terms, the first remaining term is handled on its own by a recursive
    call, and the terms those networks imply are dropped. Continues until
    no terms remain.

    NOTE(review): if the single-term recursion also makes no progress, this
    loops forever. Presumably a single term is always captured; confirm.
    """
    r = []
    icexpr = icexpr.copy()
    while not icexpr.iszero():
        t = BayesNet.from_ic(icexpr, roots = roots, add_var = add_var).tsorted()
        olen = len(icexpr.terms)
        icexpr.terms = [(a, c) for a, c in icexpr.terms if not t.check_ic(Expr.fromterm(a))]
        # Terms were edited directly; reset the cached hash (presumably).
        icexpr.mhash = None
        if len(icexpr.terms) == olen:
            tl = BayesNet.from_ic_list(Expr.fromterm(icexpr.terms[0][0]), roots = roots, add_var = add_var)
            icexpr.terms = [(a, c) for a, c in icexpr.terms if not any(t.check_ic(Expr.fromterm(a)) for t in tl)]
            icexpr.mhash = None
            r += tl
            continue
        r.append(t)
    return r
def get_markov(self):
    """Get Markov chains as a list of lists.

    Works on the topologically sorted copy cs; all index arithmetic below
    refers to positions in cs. Each chain is returned as a list of Comp
    groups, in the form consumed by Region.markov_tostring (see _latex_).

    NOTE(review): self.tsorted() returns None for a cyclic network, in
    which case this raises AttributeError.
    """
    cs = self.tsorted()
    n = cs.index.comprv.size()
    r = []
    # Smallest index among i and all of its ancestors (recursive). The
    # local r shadows the outer result list only inside this helper.
    def parent_min(i):
        r = i
        for j in cs.parent[i]:
            r = min(r, parent_min(j))
        return r
    # m = min(parent[i]) if the parents of i are exactly m..i-1 and every j
    # in m+1..i-1 has parents parent[m] plus m..j-1; otherwise -1.
    def parent_segment(i):
        if not cs.parent[i]:
            return -1
        m = min(cs.parent[i])
        if len(cs.parent[i]) != i - m:
            return -1
        for j in range(m + 1, i):
            if set(cs.parent[j]) != set(cs.parent[m]).union(range(m, j)):
                return -1
        return m
    # Smallest m such that parent[i] == parent[j] plus j..i-1 for every j
    # in m..i-1 (scanning j downward from i-1).
    def node_segment(i):
        for j in range(i - 1, -1, -1):
            if set(cs.parent[i]) != set(cs.parent[j]).union(range(j, i)):
                return j + 1
        return 0
    # Append to r the chains found among nodes st..en-1.
    def recur(st, en):
        if st >= en:
            return
        i = en - 1
        cms = [en, parent_min(i)]
        if cms[-1] > st:
            # The range splits into pieces with no ancestry between them;
            # record them as one chain with empty separator groups (their
            # meaning per markov_tostring, presumably independence), then
            # recurse into each piece.
            while cms[-1] > st:
                t = parent_min(cms[-1] - 1)
                cms.append(t)
            tl = []
            for i in range(len(cms) - 1):
                if i:
                    tl.append([])
                tl.append(list(range(cms[i + 1], cms[i])))
            r.append(tl)
            for i in range(len(cms) - 1):
                recur(cms[i + 1], cms[i])
            return
        # The last node has at least as many parents as earlier nodes in
        # range, so no chain is recorded for it; drop it.
        if len(cs.parent[i]) >= i - st:
            recur(st, i)
            return
        # Build the chain: last node's segment, then parents, walking back
        # through parent segments, then the remaining non-parents.
        m = node_segment(i)
        t = [list(range(m, i + 1))]
        i = m
        while i >= st:
            t.append(cs.parent[i])
            m = parent_segment(i)
            if m < 0:
                break
            i = m
        t.append([j for j in range(st, i) if j not in cs.parent[i]])
        while not t[-1]:
            t.pop()
        if len(t) >= 3:
            r.append(t)
        recur(st, i)
    recur(0, n)
    # Chains and their groups were built back-to-front; reverse both and
    # convert index lists to Comp.
    return [[sum((cs.index.comprv[a] for a in b), Comp.empty()) for b in reversed(tl)]
            for tl in reversed(r)]
def get_ic_sorted(self):
    """Conditional independences implied by the network in stored order.

    For each variable X_i, with E the earlier variables (by index) that
    are not parents of X_i, adds I(X_i ; E | parents(X_i)) when E is
    nonempty. Meant to be called on a topologically sorted network (see
    get_ic). Returns the sum as an Expr.
    """
    comprv = self.index.comprv
    total = Expr.zero()
    earlier = Comp.empty()
    for i in range(comprv.size()):
        parents = Comp.empty()
        for j in self.parent[i]:
            parents += comprv[j]
        others = earlier - parents
        if others.size() > 0:
            total += Expr.Ic(comprv[i], others, parents)
        earlier += comprv[i]
    return total
def get_ic_exhaust(self):
    """All implied terms I(X_i ; S | parents(X_i)) over candidate sets S.

    For each variable, every nonempty S produced by igen.subset_mask over
    the non-parent, non-self variables is tested with check_ic_mask. Each
    S that passes contributes a term. If subset_mask enumerates all
    submasks (presumably), this is exponential in the variable count.
    """
    comprv = self.index.comprv
    size = comprv.size()
    full = (1 << size) - 1
    total = Expr.zero()
    for i in range(size):
        parents = Comp.empty()
        pmask = 0
        for j in self.parent[i]:
            parents += comprv[j]
            pmask |= 1 << j
        for sub in igen.subset_mask(full - (1 << i) - pmask):
            if sub and self.check_ic_mask(1 << i, sub, pmask):
                total += Expr.Ic(comprv[i], self.index.from_mask(sub), parents)
    return total
def get_ic(self):
    """Conditional independences implied by the network (sorted first)."""
    ordered = self.tsorted()
    return ordered.get_ic_sorted()
def get_region(self, exhaust = False):
    """Region (set of constraints) equivalent to this Bayesian network.

    Functional variables contribute H(X_i | parents) = 0. Conditional
    independences come from get_ic_exhaust if exhaust, else from get_ic.
    The variable order is stored in region.meta["rv_order"].
    """
    comprv = self.index.comprv
    total = Expr.zero()
    for i in range(comprv.size()):
        if not self.fcn[i]:
            continue
        parents = Comp.empty()
        for j in self.parent[i]:
            parents += comprv[j]
        total += Expr.Hc(comprv[i], parents)
    total += self.get_ic_exhaust() if exhaust else self.get_ic()
    region = (total == 0)
    region.meta = {"rv_order": self.index.comprv.copy()}
    return region
def example_bsc(self, **kwargs):
    """Delegate to Region.universe().example_bsc with bnet set to this network."""
    universe = Region.universe()
    return universe.example_bsc(bnet = self, **kwargs)
def assume(self):
    """Assume this Bayesian network is true in the current context.
    """
    region = self.get_region()
    region.assume()
def assume_only(self):
    """Assume this Bayesian network is true in the current context. Overwrite existing assumptions.
    """
    region = self.get_region()
    region.assume_only()
def assumed(self):
    """Create a context where this Bayesian network is assumed to be true.
    Use "with bnet.assumed(): ..."
    """
    region = self.get_region()
    return region.assumed()
def assumed_only(self):
    """Create a context where this Bayesian network is assumed to be true. Overwrite existing assumptions.
    Use "with bnet.assumed_only(): ..."
    """
    region = self.get_region()
    return region.assumed_only()
@latex_postprocess
def _latex_(self):
    """LaTeX rendering: one Markov chain per line (an array if several)."""
    lines = [Region.markov_tostring(chain, PsiOpts.STR_STYLE_LATEX, False)
             for chain in self.get_markov()]
    if not lines:
        return ""
    if len(lines) == 1:
        return lines[0]
    body = "\\\\\n".join(lines)
    return "\\begin{array}{l}\n" + body + "\\\\\n\\end{array}"
def _repr_svg_(self):
    """SVG for rich display; None when graphviz is unavailable."""
    if graphviz is None:
        return None
    try:
        svg = self.graph()._repr_svg_()
    except Exception:
        # Fallback path; presumably for graphviz objects lacking
        # _repr_svg_ on the Digraph itself. TODO confirm.
        svg = self.graph().build_dot()._repr_svg_()
    return svg
def _repr_latex_(self):
    """LaTeX for rich display.

    Only used when graphviz is unavailable (SVG is preferred) and the
    "repr_latex" setting is enabled; otherwise returns None.
    """
    if graphviz is None and PsiOpts.settings.get("repr_latex", False):
        return self._latex_()
    return None
def __bool__(self):
    """Truth value of the equivalent region (see get_region)."""
    region = self.get_region()
    return bool(region)
def __or__(self, other):
    """Combine the equivalent region with other using `|`."""
    region = self.get_region()
    return region | other
def __and__(self, other):
    """Combine the equivalent region with other using `&`."""
    region = self.get_region()
    return region & other
def __lshift__(self, other):
    """Apply `<<` to the equivalent region and other."""
    region = self.get_region()
    return region << other
def __rshift__(self, other):
    """Apply `>>` to the equivalent region and other."""
    region = self.get_region()
    return region >> other
def get_basis(self, more_vars = None):
    """Get a basis of the entropy region of this Bayesian network (may not be minimal).
    """
    region = self.get_region()
    return region.get_basis(more_vars = more_vars)
def tostring(self, tsort = True):
    """Text form: one line "p1,p2 -> var" per variable, "*" marking functional ones.

    When tsort is set and the network is acyclic, variables are listed in
    topological order; otherwise in index order.
    """
    if tsort:
        ordered = self.tsorted()
        if ordered is not None:
            return ordered.tostring(tsort = False)
    names = self.index.comprv.varlist
    lines = []
    for i in range(self.index.comprv.size()):
        parents = ",".join(names[j].tostring() for j in self.parent[i])
        star = "*" if self.fcn[i] else ""
        lines.append(parents + " -> " + names[i].tostring() + star + "\n")
    return "".join(lines)
def __str__(self):
    """Same text as tostring()."""
    text = self.tostring()
    return text
def __repr__(self):
    """Same text as tostring()."""
    text = self.tostring()
    return text
def __hash__(self):
    """Hash of the tostring() representation."""
    text = self.tostring()
    return hash(text)
def graph(self, tsort = True, shape = "plaintext", lr = True, groups = None, group_color = None, group_fillcolor = None, ortho = False, **kwargs):
    """Return the graphviz digraph of the network that can be displayed in the console.

    Parameters
    ----------
    tsort : bool
        Lay out the topologically sorted copy. Cyclic networks (for which
        tsorted() returns None) are drawn unsorted instead of raising.
    shape : str
        graphviz node shape.
    lr : bool
        Left-to-right layout (rankdir=LR).
    groups : list or None
        Groups of variables drawn as filled clusters with rank=same.
    group_color, group_fillcolor : str or None
        Cluster colors; default to the PsiOpts graph_group_* settings.
    ortho : bool
        Use orthogonal edge routing.
    **kwargs
        Extra graph attributes (converted to str).

    Raises
    ------
    RuntimeError
        If graphviz is not installed.
    """
    if graphviz is None:
        raise RuntimeError("Requires graphviz. Please install it first.")
    if tsort:
        tself = self.tsorted()
        # tsorted() is None for cyclic networks; fall through and draw the
        # network as-is (mirrors tostring()).
        if tself is not None:
            return tself.graph(tsort = False, shape = shape, lr = lr, groups = groups, group_color = group_color, group_fillcolor = group_fillcolor, ortho = ortho, **kwargs)
    n = self.index.comprv.size()
    r = graphviz.Digraph()
    if lr:
        r.graph_attr["rankdir"] = "LR"
    if ortho:
        r.graph_attr["splines"] = "ortho"
    if group_color is None:
        group_color = PsiOpts.settings.get("graph_group_color")
    if group_fillcolor is None:
        group_fillcolor = PsiOpts.settings.get("graph_group_fillcolor")
    for key, value in kwargs.items():
        r.graph_attr[key] = str(value)
    if groups is None:
        groups = []

    # Node label for variable i; functional variables get a trailing "*".
    def node_label(i):
        rv_text = self.index.comprv[i]
        if self.fcn[i]:
            rv_text = rv_text.name_operation(("append", "*"))
        return rv_text.graphviz_text()

    # Variables not yet placed in a group cluster.
    remrv = self.index.comprv.copy()
    for gi, g in enumerate(groups):
        g = iutil.ensure_comp(g, ref=self)
        with r.subgraph(name = "cluster_" + str(gi)) as rcs:
            if group_color is not None:
                rcs.attr(color = group_color)
            if group_fillcolor is not None:
                rcs.attr(style = "filled")
                rcs.attr(fillcolor = group_fillcolor)
            with rcs.subgraph(name = "graphcluster_" + str(gi)) as rs:
                rs.attr(rank = "same")
                for c in g:
                    if not remrv.ispresent(c):
                        continue
                    remrv -= c
                    i = self.index.get_index(c)
                    rs.node(self.index.comprv[i].get_name(), node_label(i), shape = shape)
    for i in range(n):
        if not remrv.ispresent(self.index.comprv[i]):
            continue
        r.node(self.index.comprv[i].get_name(), node_label(i), shape = shape)
    for i in range(n):
        for j in self.parent[i]:
            r.edge(self.index.comprv[j].get_name(), self.index.comprv[i].get_name())
    return r
class ValIndex:
    """Map each value of a sequence to its position.

    get_index returns the position of x (the last one if x occurs more than
    once), or -1 when x is absent or no sequence was given.
    """
    def __init__(self, v = None):
        self.v = v
        self.vmap = {} if v is None else {x: i for i, x in enumerate(v)}
    def get_index(self, x):
        return self.vmap.get(x, -1)
class ConcDist(IBaseObj):
"""Concrete distributions / conditional distributions of random variables."""
def convert_shape(x):
    """Normalize a shape specification to a tuple.

    None -> (); int n -> (n,); Comp -> its get_shape(); ConcDist -> its
    shape_out; anything else -> tuple(x).
    """
    if x is None:
        return ()
    if isinstance(x, int):
        return (x,)
    if isinstance(x, Comp):
        return x.get_shape()
    return x.shape_out if isinstance(x, ConcDist) else tuple(x)
def convert_shape_pair(x):
    """Normalize a (shape_in, shape_out) specification to a pair of tuples.

    None -> ((), ()); int n -> ((), (n,)); Term -> (shape of its
    conditioning part, concatenated shapes of its arguments);
    Comp/ConcDist or a tuple of ints -> ((), shape); otherwise x is taken
    as a pair (shape_in, shape_out).
    """
    if x is None:
        return ((), ())
    if isinstance(x, int):
        return ((), (x,))
    if isinstance(x, Term):
        shape_in = x.z.get_shape()
        shape_out = ()
        for t in x.x:
            shape_out = shape_out + t.get_shape()
        return (shape_in, shape_out)
    if isinstance(x, (Comp, ConcDist)):
        return ((), ConcDist.convert_shape(x))
    if isinstance(x, tuple) and (not x or isinstance(x[0], int)):
        return ((), tuple(x))
    return (ConcDist.convert_shape(x[0]), ConcDist.convert_shape(x[1]))
def __init__(self, p = None, num_in = None, shape = None, shape_in = None, shape_out = None, isvar = False, randomize = False, isfcn = False, check_valid = False):
    """Create a (conditional) distribution table.

    Parameters
    ----------
    p : array-like, list, ConcDist, ExprArray or None
        Table with the num_in input axes first, then the output axes. A
        list containing Expr entries becomes an ExprArray. A constant
        ExprArray is turned into a numpy array. A symbolic one is kept in
        self.expr and only its shape is used. None builds a uniform (or
        random, if randomize) table of the requested shape.
    num_in : int or None
        Number of leading input axes. Defaults to len(shape_in), or 0, or
        the num_in of a ConcDist p.
    shape : optional
        (shape_in, shape_out) spec, see convert_shape_pair; overrides
        shape_in/shape_out.
    isvar : bool
        Trainable distribution backed by torch; requires pytorch.
    isfcn : bool
        Deterministic (functional) distribution; entries are clamped to 0/1.
    check_valid : bool
        Raise ValueError if a given non-variable p is not a valid
        distribution (via isvalid()).

    Raises
    ------
    RuntimeError
        isvar requested but torch is unavailable.
    ValueError
        check_valid failed.
    """
    # Not read in this part of the class; semantics defined elsewhere.
    self.force_float = True
    self.isvar = isvar
    self.iscache = False
    # Free torch parameters (see copy_torch / calc_torch); None if not isvar.
    self.v = None
    # if p is None and shape_in is None and shape_out is None:
    # self.p = None
    # return
    self.expr = None
    self.isfcn = isfcn
    if isinstance(p, list) and iutil.hasinstance(p, Expr):
        p = ExprArray(p)
    if isinstance(p, ExprArray) and p.isconst():
        p = p.to_numpy()
    if isinstance(p, ExprArray):
        # Symbolic table: remember it and fall through with only its shape.
        self.expr = p
        if num_in is None:
            if shape_in is None:
                num_in = 0
            else:
                num_in = len(shape_in)
        # Always true here, since num_in was just filled in.
        if num_in is not None:
            shape_in = p.shape[:num_in]
            shape_out = p.shape[num_in:]
        p = None
    if self.isvar and torch is None:
        raise RuntimeError("Requires pytorch. Please install it first.")
    if shape is not None:
        shape_in, shape_out = ConcDist.convert_shape_pair(shape)
    if isinstance(p, ConcDist):
        if num_in is None and shape is None and shape_in is None:
            num_in = p.get_num_in()
        p = p.p
    self.sublens = []
    if p is None:
        self.shape_in = ConcDist.convert_shape(shape_in)
        self.shape_out = ConcDist.convert_shape(shape_out)
        self.shape = self.shape_in + self.shape_out
        if randomize:
            self.randomize()
        else:
            self.set_uniform()
    else:
        if isinstance(p, list):
            p = numpy.array(p)
        self.p = p
        if num_in is None:
            if shape_in is not None:
                num_in = len(shape_in)
            else:
                num_in = 0
        # The table's own shape decides; shape_in/shape_out only set num_in.
        tshape = p.shape
        self.shape_in = tshape[:num_in]
        self.shape_out = tshape[num_in:]
        self.shape = self.shape_in + self.shape_out
        if self.isvar:
            self.normalize()
        else:
            if check_valid and not self.isvalid():
                raise ValueError("Invalid probability distribution. Must contain nonnegative entries that sum to 1.")
    if isfcn:
        self.clamp_fcn()
def copy(self):
    """Return an independent copy (table, symbolic expression, flags)."""
    dup = ConcDist(p = iutil.copy(self.p), shape_in = self.shape_in,
                   shape_out = self.shape_out, isvar = self.isvar, isfcn = self.isfcn)
    dup.force_float = self.force_float
    # The constructor may normalize (isvar) or clamp (isfcn) p; restore an
    # exact copy of the current table.
    dup.p = iutil.copy(self.p)
    dup.expr = iutil.copy(self.expr)
    dup.v = None
    dup.copy_torch()
    return dup
def clamp_fcn(self, randomize = False):
    """Make every conditional distribution deterministic.

    For each input xs, the output with the largest weight gets probability
    1 and all others 0; ties go to the first in iteration order. With
    randomize, each weight is first divided by an exponential random draw,
    which randomizes the winner.
    """
    rnd = PsiOpts.get_random()
    in_ranges = [range(d) for d in self.shape_in]
    out_ranges = [range(d) for d in self.shape_out]
    for xs in itertools.product(*in_ranges):
        best_zs = None
        best = -1.0
        for zs in itertools.product(*out_ranges):
            key = xs + zs
            weight = float(self.p[key])
            if randomize:
                weight /= rnd.exponential()
            self.p[key] = 0.0
            if weight > best:
                best, best_zs = weight, zs
        self.p[xs + best_zs] = 1.0
    self.copy_torch()
def fraction_snap(self, denom = None, eps = None):
    """Snap every entry with iutil.float_snap (denom, eps), then renormalize."""
    for xs in itertools.product(*(range(d) for d in self.shape_in)):
        for zs in itertools.product(*(range(d) for d in self.shape_out)):
            key = xs + zs
            self.p[key] = iutil.float_snap(float(self.p[key]), denom = denom, eps = eps)
    self.normalize()
def istorch(self):
    """True if the table p is stored as a torch.Tensor."""
    if torch is None:
        return False
    return isinstance(self.p, torch.Tensor)
def is_placeholder(self):
    """True if no table has been set (p is None)."""
    has_table = self.p is not None
    return not has_table
def get_num_in(self):
    """Number of input (conditioning) axes."""
    count = len(self.shape_in)
    return count
def card_out(self):
    """Number of joint output outcomes: product of shape_out (1 if empty)."""
    total = 1
    for size in self.shape_out:
        total = total * size
    return total
def __getitem__(self, key):
    """Index directly into the table p."""
    table = self.p
    return table[key]
def __setitem__(self, key, value):
    """Assign directly into the table p."""
    table = self.p
    table[key] = value
def copy_(self, other):
    """Copy content of other to self.

    other is either a ConcDist (whose table object is shared, not
    duplicated) or a raw table. The torch-side parameters are then
    refreshed via copy_torch.
    """
    self.p = other.p if isinstance(other, ConcDist) else other
    self.copy_torch()
def flattened_sublen(self):
    """Merge groups of output axes into single axes.

    self.sublens gives the sizes of the leading groups of output axes; the
    remaining axes form a final group. Each group is flattened into one
    axis with row-major (mixed-radix) indexing. The result is a new
    ConcDist with the same input axes and sublens set to all ones.
    """
    # NOTE(review): shape_its is unused.
    shape_its = []
    shape_out_sub = []
    c = 0
    sublens = list(self.sublens) + [len(self.shape_out) - sum(self.sublens)]
    for l in sublens:
        # Size of the merged axis = product of the group's axis sizes.
        shape_out_sub.append(iutil.product(self.shape_out[c:c+l]))
        c += l
    shape_out_sub = tuple(shape_out_sub)
    r = numpy.zeros(self.shape_in + shape_out_sub)
    if torch is not None and isinstance(self.p, torch.Tensor):
        r = torch.tensor(r, dtype=torch.float64)
    for xs in itertools.product(*[range(x) for x in self.shape_in]):
        for zs in itertools.product(*[range(z) for z in self.shape_out]):
            ids = []
            c = 0
            for l in sublens:
                # Row-major index of zs[c:c+l] within its group.
                t = 0
                for i in range(c, c + l):
                    t = t * self.shape_out[i] + zs[i]
                ids.append(t)
                c += l
            r[xs + tuple(ids)] = self.p[xs + zs]
    r = ConcDist(r, num_in = self.get_num_in())
    r.sublens = [1] * (len(shape_out_sub) - 1)
    return r
def calc_torch(self):
    """Rebuild the torch table p from the free parameters v (isvar only).

    For each input, the first card_out - 1 outputs (in iteration order)
    take their values from v, and the last output gets 1 minus their sum.
    """
    if not self.isvar:
        return
    last = self.card_out() - 1
    self.p = torch.zeros(self.shape_in + self.shape_out, dtype=torch.float64)
    for xs in itertools.product(*[range(d) for d in self.shape_in]):
        acc = 0.0
        for pos, zs in enumerate(itertools.product(*[range(d) for d in self.shape_out])):
            if pos < last:
                self.p[xs + zs] = self.v[xs + (pos,)]
                acc += self.v[xs + (pos,)]
            else:
                self.p[xs + zs] = 1.0 - acc
def copy_torch(self):
    """Sync the free torch parameters v from the table p (isvar only).

    For non-variable distributions v is cleared. Otherwise v (shape
    shape_in + (card_out - 1,), created on first use with requires_grad)
    receives the first card_out - 1 outputs of p for each input. Then
    calc_torch rebuilds p from v.
    """
    if not self.isvar:
        self.v = None
        return
    card_out = self.card_out()
    # self.p = self.p.numpy()
    if self.v is None:
        self.v = numpy.zeros(self.shape_in + (card_out - 1,))
        self.v = torch.tensor(self.v, dtype=torch.float64, requires_grad = True)
    # In-place writes to a leaf tensor with requires_grad must bypass autograd.
    with torch.no_grad():
        for xs in itertools.product(*[range(x) for x in self.shape_in]):
            ci = 0
            for zs in itertools.product(*[range(z) for z in self.shape_out]):
                if ci < card_out - 1:
                    self.v[xs + (ci,)] = float(self.p[xs + zs])
                ci += 1
    self.calc_torch()
def normalize(self):
    """Rescale each conditional distribution to sum to 1.

    Functional distributions are clamped instead (clamp_fcn). For each
    input, the outputs are divided by their sum when it exceeds eps.
    Otherwise they are set to uniform. The torch parameters are then
    re-synced via copy_torch.

    NOTE(review): if p has an integer dtype, the in-place `/=` truncates.
    Confirm p is always floating point here.
    """
    if self.isfcn:
        self.clamp_fcn()
        return
    ceps = PsiOpts.settings["eps"]
    # NOTE(review): ceps_d is unused here.
    ceps_d = PsiOpts.settings["opt_eps_denom"]
    card_out = self.card_out()
    if self.isvar and torch is not None and isinstance(self.p, torch.Tensor):
        # Work on a numpy view of the detached tensor (shares memory).
        self.p = self.p.detach().numpy()
    for xs in itertools.product(*[range(x) for x in self.shape_in]):
        s = 0.0
        for zs in itertools.product(*[range(z) for z in self.shape_out]):
            s += self.p[xs + zs]
        if s > ceps:
            for zs in itertools.product(*[range(z) for z in self.shape_out]):
                self.p[xs + zs] /= s
        else:
            for zs in itertools.product(*[range(z) for z in self.shape_out]):
                self.p[xs + zs] = 1.0 / card_out
    self.copy_torch()
def clamp(self):
    """Project the free parameters v back to a valid distribution (isvar only).

    Functional distributions are re-clamped with randomization. Otherwise,
    for each input, the card_out - 1 free entries are clipped at 0. If
    their sum exceeds 1 (+eps), the positive entries are repeatedly
    lowered by equal amounts until the sum is at most 1 (when card_out >
    2), and each entry is capped at 1. A NaN entry triggers randomize()
    instead. The table p is then rebuilt with calc_torch.
    """
    if not self.isvar:
        return
    if self.isfcn:
        self.clamp_fcn(randomize = True)
        return
    ceps = PsiOpts.settings["eps"]
    # NOTE(review): ceps_d is unused here.
    ceps_d = PsiOpts.settings["opt_eps_denom"]
    card_out = self.card_out()
    vt = self.v.detach().numpy()
    for xs in itertools.product(*[range(x) for x in self.shape_in]):
        for z in range(card_out - 1):
            t = vt[xs + (z,)]
            if numpy.isnan(t):
                self.randomize()
                return
            vt[xs + (z,)] = max(t, 0.0)
        if card_out > 2:
            while True:
                s = 0.0
                minpos = 1e20
                numpos = 0
                for z in range(card_out - 1):
                    if vt[xs + (z,)] > 0:
                        s += vt[xs + (z,)]
                        minpos = min(minpos, vt[xs + (z,)])
                        numpos += 1
                if s <= 1.0 + ceps:
                    break
                # Spread the excess over the positive entries; if that would
                # push the smallest one below 0, remove only minpos this
                # round (zeroing it) and repeat.
                tored = (s - 1.0) / numpos
                good = False
                if tored <= minpos:
                    good = True
                else:
                    tored = minpos
                for z in range(card_out - 1):
                    if vt[xs + (z,)] > 0:
                        vt[xs + (z,)] -= tored
                if good:
                    break
        for z in range(card_out - 1):
            vt[xs + (z,)] = min(vt[xs + (z,)], 1.0)
    # for z in range(card_out - 1):
    # self.v[xs + (z,)] = vt[xs + (z,)]
    with torch.no_grad():
        self.v.copy_(torch.tensor(vt, dtype=torch.float64))
    # self.v.copy_(torch.tensor(vt))
    # with torch.no_grad():
    # for xs in itertools.product(*[range(x) for x in self.shape_in]):
    # for z in range(card_out - 1):
    # self.v[xs + (z,)] = vt[xs + (z,)]
    # self.v = torch.tensor(self.v, requires_grad = True)
    self.calc_torch()
def set_uniform(self):
    """Reset p to the uniform distribution over outputs for every input."""
    prob = 1.0 / self.card_out()
    self.p = numpy.ones(self.shape_in + self.shape_out) * prob
    if self.isvar:
        self.normalize()
def randomize(self):
    """Fill p with independent exponential draws, then normalize."""
    draws = PsiOpts.get_random().exponential(size = self.shape)
    self.p = draws
    self.normalize()
def hop(self, prob):
    """Randomly perturb the distribution.

    For each input, with probability prob, the outputs are redrawn from
    exponential weights. Everything is then normalized.
    """
    rnd = PsiOpts.get_random()
    if torch is not None and isinstance(self.p, torch.Tensor):
        self.p = self.p.detach().numpy()
    out_ranges = [range(d) for d in self.shape_out]
    for xs in itertools.product(*[range(d) for d in self.shape_in]):
        if rnd.uniform() >= prob:
            continue
        for zs in itertools.product(*out_ranges):
            self.p[xs + zs] = rnd.exponential()
    self.normalize()
def get_p(self):
    """The probability table (numpy array or torch.Tensor)."""
    table = self.p
    return table
def get_v(self):
    """The free torch parameters (None unless isvar)."""
    params = self.v
    return params
def numpy(self):
    """Convert to numpy array."""
    if not iutil.istorch(self.p):
        return self.p
    return self.p.detach().numpy()
def torch(self):
    """Convert to torch.Tensor."""
    if not iutil.istorch(self.p):
        return torch.tensor(self.p, dtype=torch.float64)
    return self.p
def entropy(self):
    """Entropy of this distribution.

    Sums -c log(c) * ent_coeff over every entry c of the full table
    (input and output axes). With torch, a smoothed log is used so that
    zero entries stay differentiable. With numpy, entries <= eps are
    skipped. Returns a ConcReal.
    """
    ceps = PsiOpts.settings["eps"]
    ceps_d = PsiOpts.settings["opt_eps_denom"]
    loge = PsiOpts.settings["ent_coeff"]
    use_torch = torch is not None and isinstance(self.p, torch.Tensor)
    total = 0.0
    for idx in itertools.product(*[range(d) for d in self.shape]):
        c = self.p[idx]
        if use_torch:
            total -= c * torch.log((c + ceps_d) / (1.0 + ceps_d)) * loge
        elif c > ceps:
            total -= c * numpy.log(c) * loge
    return ConcReal(total)
def items(self):
    """Yield every entry of p: input index outer, output index inner."""
    ins = [range(d) for d in self.shape_in]
    outs = [range(d) for d in self.shape_out]
    for xs in itertools.product(*ins):
        for zs in itertools.product(*outs):
            yield self.p[xs + zs]
def convert(x):
    """Return x unchanged if it is a ConcDist, else wrap it as ConcDist(x)."""
    return x if isinstance(x, ConcDist) else ConcDist(x)
def __add__(self, other):
    """Entrywise sum; input and output shapes must match exactly."""
    rhs = ConcDist.convert(other)
    if self.shape_in != rhs.shape_in or self.shape_out != rhs.shape_out:
        raise ValueError("Shape mismatch.")
    return ConcDist(self.p + rhs.p, num_in = self.get_num_in())
def __radd__(self, other):
    """Addition is symmetric: other + self == self + other."""
    total = self + other
    return total
def __mul__(self, other):
    """Scalar multiple, or product of two channels with the same inputs.

    For distributions, P(Y|X) * Q(W|X) gives P(Y|X) Q(W|X) over outputs
    (Y, W); input shapes must match.
    """
    if isinstance(other, (int, float)):
        return ConcDist(self.p * float(other), num_in = self.get_num_in())
    rhs = ConcDist.convert(other)
    if self.shape_in != rhs.shape_in:
        raise ValueError("Shape mismatch.")
    shape = self.shape_in + self.shape_out + rhs.shape_out
    if torch is not None and (isinstance(self.p, torch.Tensor) or isinstance(rhs.p, torch.Tensor)):
        out = torch.zeros(shape, dtype=torch.float64)
    else:
        out = numpy.zeros(shape)
    for xs in itertools.product(*[range(d) for d in self.shape_in]):
        for zs in itertools.product(*[range(d) for d in self.shape_out]):
            left = self.p[xs + zs]
            for ws in itertools.product(*[range(d) for d in rhs.shape_out]):
                out[xs + zs + ws] += left * rhs.p[xs + ws]
    return ConcDist(out, num_in = self.get_num_in())
def __rmul__(self, other):
    """Right multiplication: scalars commute, otherwise convert(other) * self."""
    if isinstance(other, (int, float)):
        return self * other
    return ConcDist.convert(other) * self
def __pow__(self, other):
    """n-fold product (see __mul__) of this channel with itself, same inputs."""
    result = ConcDist(shape_in = self.shape_in, shape_out = tuple())
    for _ in range(other):
        result = result * self
    return result
def chan_product(self, other):
    """Product channel.

    Combines P(Y|X) and Q(W|V) into the channel from (X, V) to (Y, W):
    input axes are self's then other's, followed by the output axes in the
    same order.
    """
    rhs = ConcDist.convert(other)
    shape = self.shape_in + rhs.shape_in + self.shape_out + rhs.shape_out
    if torch is not None and (isinstance(self.p, torch.Tensor) or isinstance(rhs.p, torch.Tensor)):
        out = torch.zeros(shape, dtype=torch.float64)
    else:
        out = numpy.zeros(shape)

    def cells(dims):
        return itertools.product(*[range(d) for d in dims])

    for xs, x2s, zs, ws in itertools.product(cells(self.shape_in), cells(rhs.shape_in),
                                            cells(self.shape_out), cells(rhs.shape_out)):
        out[xs + x2s + zs + ws] += self.p[xs + zs] * rhs.p[x2s + ws]
    return ConcDist(out, num_in = self.get_num_in() + rhs.get_num_in())
def chan_power(self, other):
    """n-product channel: chan_product of `other` copies of this channel."""
    result = ConcDist(shape_in = tuple(), shape_out = tuple())
    for _ in range(other):
        result = result.chan_product(self)
    return result
def __matmul__(self, other):
    """Channel composition: (P @ Q)(W|X) = sum over Z of P(Z|X) Q(W|Z).

    self's output shape must equal other's input shape.
    """
    rhs = ConcDist.convert(other)
    if self.shape_out != rhs.shape_in:
        raise ValueError("Shape mismatch.")
    shape = self.shape_in + rhs.shape_out
    if torch is not None and (isinstance(self.p, torch.Tensor) or isinstance(rhs.p, torch.Tensor)):
        out = torch.zeros(shape, dtype=torch.float64)
    else:
        out = numpy.zeros(shape)

    def cells(dims):
        return itertools.product(*[range(d) for d in dims])

    for xs, zs, ws in itertools.product(cells(self.shape_in), cells(self.shape_out), cells(rhs.shape_out)):
        out[xs + ws] += self.p[xs + zs] * rhs.p[zs + ws]
    return ConcDist(out, num_in = self.get_num_in())
def __truediv__(self, other):
    """Divide by a scalar, or condition on a marginal.

    - ``P / c`` for an int/float c scales every entry by 1/c.
    - ``P(Y,Z|X) / Q(Y|X)`` returns P(Z|X,Y). other must have the same
      input shape as self, and its output shape must be a prefix of
      self's output shape. Those leading outputs become inputs of the
      result.

    Where Q(y|x) is (numerically) zero, the numpy path returns a uniform
    P(Z|x,y) and the torch path uses eps-smoothing (opt_eps_denom). other
    may be anything ConcDist.convert accepts, like for the other
    arithmetic operators.

    Raises
    ------
    ValueError
        On shape mismatch.
    """
    if isinstance(other, int) or isinstance(other, float):
        return ConcDist(self.p * float(1.0 / other), num_in = self.get_num_in())
    # Accept raw arrays/lists, consistent with __add__, __mul__, __matmul__.
    other = ConcDist.convert(other)
    if self.shape_in != other.shape_in:
        raise ValueError("Shape mismatch.")
    cshape_in = self.shape_out[:len(other.shape_out)]
    if cshape_in != other.shape_out:
        raise ValueError("Shape mismatch.")
    cshape_out = self.shape_out[len(other.shape_out):]
    ceps = PsiOpts.settings["eps"]
    ceps_d = PsiOpts.settings["opt_eps_denom"]
    zsn = 1
    for k in cshape_out:
        zsn *= k
    # Smoothing mass spread evenly over the zsn remaining outcomes.
    cepsdzsn = ceps_d / zsn
    r = None
    if torch is not None and (isinstance(self.p, torch.Tensor) or isinstance(other.p, torch.Tensor)):
        r = torch.zeros(self.shape_in + cshape_in + cshape_out, dtype=torch.float64)
        for xs in itertools.product(*[range(x) for x in self.shape_in]):
            for zs in itertools.product(*[range(z) for z in cshape_in]):
                for ws in itertools.product(*[range(w) for w in cshape_out]):
                    r[xs + zs + ws] = (self.p[xs + zs + ws] + cepsdzsn) / (other.p[xs + zs] + ceps_d)
    else:
        r = numpy.zeros(self.shape_in + cshape_in + cshape_out)
        for xs in itertools.product(*[range(x) for x in self.shape_in]):
            for zs in itertools.product(*[range(z) for z in cshape_in]):
                if other.p[xs + zs] > ceps:
                    for ws in itertools.product(*[range(w) for w in cshape_out]):
                        r[xs + zs + ws] = self.p[xs + zs + ws] / other.p[xs + zs]
                else:
                    for ws in itertools.product(*[range(w) for w in cshape_out]):
                        r[xs + zs + ws] = 1.0 / zsn
    return ConcDist(r, num_in = len(self.shape_in + cshape_in))
def semidirect(self, other, ids = None):
    """Semidirect product.

    Combines P(Y|X) with Q(W|Y'), where Y' are the outputs of self
    selected by ids (default: the first len(other.shape_in) outputs), into
    P(Y,W|X) = P(Y|X) Q(W|Y').
    """
    if ids is None:
        ids = range(len(other.shape_in))
    shape = self.shape_in + self.shape_out + other.shape_out
    if torch is not None and (isinstance(self.p, torch.Tensor) or isinstance(other.p, torch.Tensor)):
        out = torch.zeros(shape, dtype=torch.float64)
    else:
        out = numpy.zeros(shape)

    def cells(dims):
        return itertools.product(*[range(d) for d in dims])

    for xs, zs, ws in itertools.product(cells(self.shape_in), cells(self.shape_out), cells(other.shape_out)):
        picked = tuple(zs[i] for i in ids)
        out[xs + zs + ws] = self.p[xs + zs] * other.p[picked + ws]
    return ConcDist(out, num_in = self.get_num_in())
def marginal(self, *args):
    """
    Marginal distribution.

    Parameters
    ----------
    *args : int, or a single list/tuple/range of int
        Indices of the random variables of interest. E.g. for P(Y0,Y1,Y2|X),
        P.marginal(0,2) and P.marginal([0,2]) both give P(Y0,Y2|X).
        If an index is repeated, entries where the repeated positions take
        different values are 0.

    Returns
    -------
    ConcDist
        The marginal distribution (conditioning variables are kept).
    """
    # Accept both P.marginal(0, 2) and P.marginal([0, 2]).
    if len(args) == 1 and isinstance(args[0], (list, tuple, range)):
        ids = tuple(args[0])
    else:
        ids = args
    cshape = tuple(self.shape_out[i] for i in ids)
    if torch is not None and isinstance(self.p, torch.Tensor):
        r = torch.zeros(self.shape_in + cshape, dtype=torch.float64)
    else:
        r = numpy.zeros(self.shape_in + cshape)
    for xs in itertools.product(*[range(x) for x in self.shape_in]):
        for zs in itertools.product(*[range(z) for z in cshape]):
            # Pin every selected output axis to its value in zs; sum over the rest.
            wrange = [range(w) for w in self.shape_out]
            for j, z in zip(ids, zs):
                if len(wrange[j]) == 1 and wrange[j][0] != z:
                    # Axis j was already pinned to a different value (repeated index).
                    break
                wrange[j] = [z]
            else:
                for ws in itertools.product(*wrange):
                    r[xs + zs] += self.p[xs + ws]
    return ConcDist(r, num_in = self.get_num_in())
def reorder(self, ids):
    """Permute the output random variables.

    Output axis k of the result is output axis ``ids[k]`` of self; the
    conditioning axes stay in front, unchanged.

    Parameters
    ----------
    ids : sequence of int
        Positions in self.shape_out, listed in the desired new order.

    Returns
    -------
    ConcDist
    """
    permuted_shape = tuple(self.shape_out[j] for j in ids)
    if torch is not None and isinstance(self.p, torch.Tensor):
        r = torch.zeros(self.shape_in + permuted_shape, dtype=torch.float64)
    else:
        r = numpy.zeros(self.shape_in + permuted_shape)
    cond_ranges = [range(n) for n in self.shape_in]
    out_ranges = [range(n) for n in self.shape_out]
    for xs in itertools.product(*cond_ranges):
        for zs in itertools.product(*out_ranges):
            target = xs + tuple(zs[j] for j in ids)
            r[target] = self.p[xs + zs]
    return ConcDist(r, num_in = self.get_num_in())
def given(self, *args):
    """
    For a conditional distribution P(Y|X), give the distribution P(Y|X=x).

    Parameters
    ----------
    *args : int or None
        One argument per conditioned random variable: an int fixes that
        variable to the given value, None keeps it as a conditioning variable.

    Returns
    -------
    ConcDist
        The distribution after substitution. Its conditioning variables are
        those passed as None, in their original order.

    Raises
    ------
    ValueError
        If the number of arguments differs from the number of conditioned
        random variables.
    """
    nin = len(self.shape_in)
    if len(args) != nin:
        raise ValueError("Expected " + str(nin) + " argument(s), one per conditioned random variable, got "
                         + str(len(args)) + ".")
    free = [i for i in range(nin) if args[i] is None]
    cshape = tuple(self.shape_in[i] for i in free)
    if torch is not None and isinstance(self.p, torch.Tensor):
        r = torch.zeros(cshape + self.shape_out, dtype=torch.float64)
    else:
        r = numpy.zeros(cshape + self.shape_out)
    for xs in itertools.product(*[range(x) for x in cshape]):
        # Fill the free (None) positions with the current values in xs.
        full = list(args)
        for i, v in zip(free, xs):
            full[i] = v
        full = tuple(full)
        for zs in itertools.product(*[range(z) for z in self.shape_out]):
            r[xs + zs] = self.p[full + zs]
    return ConcDist(r, num_in = len(cshape))
def mean(self, f = None):
    """
    Returns the expectation of the function f.

    The result is the sum of p[idx] * f(idx) over every cell of the table,
    input and output indices included.

    Parameters
    ----------
    f : function, numpy.array or torch.Tensor
        If f is a function, the number of arguments must match the number of
        dimensions (random variables) of the joint distribution.
        If f is an array or tensor, shape must match the shape of the
        distribution. Defaults to the identity (single-variable tables).

    Returns
    -------
    r : float or torch.Tensor
        The expectation (None for a table without cells). Type is
        torch.Tensor if self or f is torch.Tensor.
    """
    if f is None:
        f = lambda x: x
    use_call = callable(f)
    total = None
    cells = itertools.product(*[range(n) for n in self.shape_in],
                              *[range(n) for n in self.shape_out])
    for idx in cells:
        weight = f(*idx) if use_call else f[idx]
        term = self.p[idx] * weight
        if total is None:
            total = term
        else:
            total += term
    return total
def __str__(self):
    """Text form of the underlying probability table."""
    table = self.p
    return str(table)
def __repr__(self):
    """Representation of the form ``ConcDist(<table>[, num_in=k])``.

    A newline follows the opening parenthesis when the table's own repr
    spans several lines. ``num_in`` is shown only when positive.
    """
    body = repr(self.p)
    pieces = ["ConcDist("]
    if "\n" in body:
        pieces.append("\n")
    pieces.append(body)
    nin = self.get_num_in()
    if nin > 0:
        pieces.append(", num_in=" + str(nin))
    pieces.append(")")
    return "".join(pieces)
@latex_postprocess
def _latex_(self):
    """LaTeX rendering of the probability table, produced through ExprArray."""
    table = ExprArray(self.p)
    table.set_float(self.force_float)
    rendered = table._latex_()
    return rendered
def tostring(self, style = 0):
    """Convert to string.

    Parameters:
        style : Style of string conversion
                STR_STYLE_STANDARD : I(X,Y;Z|W)
                STR_STYLE_PSITIP : I(X+Y&Z|W)
    LaTeX styles delegate to _latex_(); every other style uses str(self).
    """
    style = iutil.convert_str_style(style)
    wants_latex = style & PsiOpts.STR_STYLE_LATEX
    if not wants_latex:
        return str(self)
    return self._latex_()
def fcn(fcncall, shape):
    """Distribution whose table is ExprArray.fcn(fcncall, shape_in + shape_out).

    ``shape`` is split into (shape_in, shape_out) by
    ConcDist.convert_shape_pair.
    """
    s_in, s_out = ConcDist.convert_shape_pair(shape)
    table = ExprArray.fcn(fcncall, s_in + s_out)
    return ConcDist(table, shape_in = s_in, shape_out = s_out)
def det_fcn(fcncall, shape, isvar = False):
    """Deterministic conditional distribution P(Y|X) with Y = fcncall(X).

    Parameters
    ----------
    fcncall : callable
        Called as ``fcncall(*xs)`` for every input index tuple ``xs``. It may
        return a single value (bool/int/float/numpy scalar, converted with
        ``int``) or a list/tuple/1-D array with one value per output random
        variable.
    shape : shape specification accepted by ConcDist.convert_shape_pair
        Output cardinalities that are None or too small are enlarged to fit
        the values returned by fcncall. Missing output dimensions are added.
    isvar : bool
        Passed to ConcDist.

    Raises
    ------
    ValueError
        If fcncall returns a negative value, or fewer values than there are
        output dimensions.
    """
    shape_in, shape_out = ConcDist.convert_shape_pair(shape)
    shape_in = tuple(shape_in)
    shape_out = list(shape_out)
    ys = []
    for xs in itertools.product(*[range(x) for x in shape_in]):
        t = fcncall(*xs)
        # Normalize to a tuple of Python ints (handles bool, float, numpy scalars, lists).
        if isinstance(t, (list, tuple, numpy.ndarray)):
            t = tuple(int(a) for a in t)
        else:
            t = (int(t),)
        if any(a < 0 for a in t):
            # Negative indices would silently wrap around in numpy.
            raise ValueError("det_fcn: function returned a negative value " + str(t) + ".")
        while len(shape_out) < len(t):
            shape_out.append(None)
        if len(t) != len(shape_out):
            # A shorter index would set a whole slice of the table to 1.0.
            raise ValueError("det_fcn: function returned " + str(len(t)) + " value(s), expected "
                             + str(len(shape_out)) + ".")
        for i, a in enumerate(t):
            if shape_out[i] is None or shape_out[i] < a + 1:
                shape_out[i] = a + 1
        ys.append(t)
    shape_out = tuple(shape_out)
    p = numpy.zeros(shape_in + shape_out)
    for xs, t in zip(itertools.product(*[range(x) for x in shape_in]), ys):
        p[xs + t] = 1.0
    return ConcDist(p, num_in = len(shape_in), isvar = isvar)
def isvalid(self):
    """Whether this distribution is valid.

    For every input index, all output entries must be >= -eps_check and
    their sum must be within eps_check of 1 (PsiOpts.settings["eps_check"]).
    """
    tol = PsiOpts.settings["eps_check"]
    out_ranges = [range(n) for n in self.shape_out]
    for xs in itertools.product(*[range(n) for n in self.shape_in]):
        total = 0.0
        for zs in itertools.product(*out_ranges):
            val = float(self.p[xs + zs])
            if val < -tol:
                return False
            total += val
        if abs(total - 1.0) > tol:
            return False
    return True
def valid_region(self, skip_simplify = False):
    """For a symbolic distribution, returns the region where this is a
    valid distribution.

    Every entry of self.expr must be >= 0, and the entries of each
    conditional slice must sum to 1. A non-symbolic distribution
    (expr is None) gives the universe region. The region is simplified
    unless skip_simplify is True.
    """
    r = Region.universe()
    if self.expr is None:
        return r
    out_ranges = [range(n) for n in self.shape_out]
    for xs in itertools.product(*[range(n) for n in self.shape_in]):
        total = Expr.zero()
        for zs in itertools.product(*out_ranges):
            entry = self.expr[xs + zs]
            r.iand_norename(entry >= 0)
            total += entry
        r.iand_norename(total == 1)
    return r if skip_simplify else r.simplified()
def uniform(n, isvar = False):
    """n-ary uniform distribution (probability 1/n on each of 0..n-1)."""
    probs = numpy.ones(n) / n
    return ConcDist(probs, isvar = isvar)
def bit(isvar = False):
    """Fair bit: the uniform distribution on {0, 1}."""
    outcomes = 2
    return ConcDist.uniform(outcomes, isvar = isvar)
def bern(a, isvar = False):
    """Bernoulli distribution with P(1) = a and P(0) = 1 - a."""
    p_one = a
    p_zero = 1.0 - p_one
    return ConcDist([p_zero, p_one], isvar = isvar)
def random(n, isvar = False):
    """n-ary random distribution: start uniform, then call randomize()."""
    dist = ConcDist(numpy.ones(n) / n, isvar = isvar)
    dist.randomize()
    return dist
def symm_chan(n, crossover, isvar = False):
    """n-ary symmetric channel.

    Input i goes to output i with probability 1 - crossover and to each of
    the other n - 1 outputs with probability crossover / (n - 1).
    For n == 1 there is no other output to cross over to, so the channel
    is the identity [[1.0]] (this previously raised ZeroDivisionError).
    """
    if n == 1:
        return ConcDist([[1.0]], num_in = 1, isvar = isvar)
    stay = 1.0 - crossover
    move = crossover / (n - 1)
    return ConcDist([[stay if i == j else move for j in range(n)] for i in range(n)],
                    num_in = 1, isvar = isvar)
def bsc(crossover, isvar = False):
    """Binary symmetric channel: flips the input bit with probability crossover."""
    alphabet = 2
    return ConcDist.symm_chan(alphabet, crossover, isvar = isvar)
def bin_chan(cross01, cross10, isvar = False):
    """Binary channel with P(Y=1|X=0) = cross01 and P(Y=0|X=1) = cross10."""
    from_zero = [1.0 - cross01, cross01]
    from_one = [cross10, 1.0 - cross10]
    return ConcDist([from_zero, from_one], num_in = 1, isvar = isvar)
def erasure_chan(n, er_prob, isvar = False):
    """n-ary erasure channel.

    Input i is delivered as output i with probability 1 - er_prob and
    replaced by the extra erasure symbol n with probability er_prob.
    """
    keep = 1.0 - er_prob
    rows = []
    for i in range(n):
        row = [keep if j == i else 0.0 for j in range(n)]
        row.append(er_prob)
        rows.append(row)
    return ConcDist(rows, num_in = 1, isvar = isvar)
def bec(er_prob, isvar = False):
    """Binary erasure channel (outputs 0, 1, and the erasure symbol 2)."""
    alphabet = 2
    return ConcDist.erasure_chan(alphabet, er_prob, isvar = isvar)
def equal(shape_in, isvar = False):
    """Transition probability for two equal random vectors: P(Y = x | X = x) = 1."""
    if isinstance(shape_in, int):
        shape_in = (shape_in,)
    table = numpy.zeros(shape_in + shape_in)
    for idx in itertools.product(*map(range, shape_in)):
        table[idx + idx] = 1.0
    return ConcDist(table, num_in = len(shape_in), isvar = isvar)
def flat(shape_in, isvar = False):
    """Transition probability for flattening a random vector into one random variable.

    The vector (x_1, ..., x_k) maps to its row-major (mixed-radix) index,
    with x_1 the most significant digit.
    """
    if isinstance(shape_in, int):
        shape_in = (shape_in,)
    total = iutil.product(shape_in)
    table = numpy.zeros(shape_in + (total,))
    for idx in itertools.product(*map(range, shape_in)):
        flat_index = 0
        for digit, radix in zip(idx, shape_in):
            flat_index = flat_index * radix + digit
        table[idx + (flat_index,)] = 1.0
    return ConcDist(table, num_in = len(shape_in), isvar = isvar)
def add(shape_in, isvar = False):
    """Transition probability from several random variables to their sum.

    With cardinalities n_1, ..., n_k the sum ranges over
    0 .. sum(n_i - 1), i.e. sum(n_i) - k + 1 possible values.
    """
    if isinstance(shape_in, int):
        shape_in = (shape_in,)
    n_sum = 1 + sum(card - 1 for card in shape_in)
    table = numpy.zeros(shape_in + (n_sum,))
    for idx in itertools.product(*map(range, shape_in)):
        table[idx + (sum(idx),)] = 1.0
    return ConcDist(table, num_in = len(shape_in), isvar = isvar)
def gaussian(r, l, isvar = False):
    """Quantized standard Gaussian distribution in the range [-r, r],
    divided into l cells.

    Cell i gets the standard normal mass between consecutive equally spaced
    edges. The masses are renormalized to sum to 1 over [-r, r].
    """
    sqrt2 = numpy.sqrt(2.0)
    # Standard normal CDF at the l + 1 cell edges -r, ..., r.
    edges = [(i * 2.0 / l - 1.0) * r for i in range(l + 1)]
    cdfs = [0.5 * (1 + scipy.special.erf(x / sqrt2)) for x in edges]
    p = numpy.zeros(l)
    for i in range(l):
        p[i] = cdfs[i + 1] - cdfs[i]
    p /= cdfs[l] - cdfs[0]
    return ConcDist(p, num_in = 0, isvar = isvar)
def convolve_kernel(shape_in, kernel, isvar = False):
    """Transition probability from X to X+Z, where X is a random vector with a
    pmf of shape shape_in, and Z is independent of X and follows the distribution
    given by kernel.

    Parameters
    ----------
    shape_in : int or tuple of int
        Cardinalities of the components of X.
    kernel : ConcDist, array-like, or nested list
        pmf of Z. It must have one dimension per component of X.
    isvar : bool
        Passed to ConcDist.

    Raises
    ------
    ValueError
        If the number of kernel dimensions differs from len(shape_in).
    """
    if isinstance(shape_in, int):
        shape_in = (shape_in,)
    if isinstance(kernel, ConcDist):
        kernel = kernel.p
    if isinstance(kernel, (list, tuple)):
        # Plain Python sequences have no .shape; convert them to an array.
        kernel = numpy.array(kernel, dtype=float)
    if len(kernel.shape) != len(shape_in):
        # zip() below would otherwise silently drop the extra dimensions.
        raise ValueError("Kernel has " + str(len(kernel.shape)) + " dimension(s) but shape_in has "
                         + str(len(shape_in)) + ".")
    shape_out = tuple(a + b - 1 for a, b in zip(shape_in, kernel.shape))
    p = numpy.zeros(shape_in + shape_out)
    for xs in itertools.product(*[range(x) for x in shape_in]):
        for zs in itertools.product(*[range(x) for x in kernel.shape]):
            # For fixed xs, distinct zs give distinct outputs xs + zs, so plain assignment is enough.
            p[xs + tuple(x + z for x, z in zip(xs, zs))] = kernel[zs]
    return ConcDist(p, num_in = len(shape_in), isvar = isvar)
class ConcReal(IBaseObj):
    """Concrete real variable.

    Attributes
    ----------
    x : float or torch.Tensor
        Current value. When isvar is True this is the torch tensor v.
    v : torch.Tensor or None
        Tensor with requires_grad=True backing x when isvar is True.
    lbound, ubound : float or None
        Bounds on the value; None means unbounded on that side.
    scale : float
        Scale of the exponential draw used by randomize() when a side is
        unbounded.
    isvar : bool
        Whether this is an optimization variable (requires pytorch).
    isint : bool
        Whether clamp() rounds the value to an integer.
    """

    def __init__(self, x = None, lbound = None, ubound = None, scale = 1.0, isvar = False, isint = False, randomize = False):
        self.force_float = True
        self.isvar = isvar
        self.isint = isint
        if self.isvar and torch is None:
            raise RuntimeError("Requires pytorch. Please install it first.")
        if x is None:
            x = 0.0
        if isinstance(x, int):
            x = float(x)
        if isinstance(x, ConcReal):
            x = x.x
        if isinstance(lbound, int):
            lbound = float(lbound)
        if isinstance(ubound, int):
            ubound = float(ubound)
        if isinstance(scale, int):
            scale = float(scale)
        self.x = x
        self.v = None
        self.lbound = lbound
        self.ubound = ubound
        self.scale = scale
        self.copy_torch()
        if randomize:
            self.randomize()

    def copy(self):
        """Independent copy; a variable copy gets its own fresh tensor."""
        r = ConcReal(x = iutil.copy(self.x), lbound = self.lbound, ubound = self.ubound, scale = self.scale, isvar = self.isvar, isint = self.isint)
        r.force_float = self.force_float
        r.x = iutil.copy(self.x)
        r.v = None
        r.copy_torch()
        return r

    @staticmethod
    def const(x = 0.0):
        """Constant: value x with lbound = ubound = x.
        """
        return ConcReal(x, x, x)

    @staticmethod
    def convert(x):
        """Coerce x to ConcReal: ints/floats become constants, other values are wrapped."""
        if isinstance(x, int) or isinstance(x, float):
            x = ConcReal.const(x)
        elif not isinstance(x, ConcReal):
            x = ConcReal(x)
        return x

    def calc_torch(self):
        if self.isvar:
            self.x = self.v

    def copy_torch(self):
        """For variables, (re)load x into the tensor v and make x refer to v."""
        if self.isvar:
            if self.v is None:
                self.v = torch.tensor(self.x, dtype=torch.float64, requires_grad = True)
            with torch.no_grad():
                self.v.copy_(torch.tensor(self.x, dtype=torch.float64))
            self.x = self.v

    def copy_(self, other):
        """Copy content of other to self.
        """
        if isinstance(other, int):
            other = float(other)
        if isinstance(other, ConcReal):
            self.x = other.x
        else:
            self.x = other
        self.copy_torch()

    def __add__(self, other):
        other = ConcReal.convert(other)
        return ConcReal(self.x + other.x,
                        None if self.lbound is None or other.lbound is None else self.lbound + other.lbound,
                        None if self.ubound is None or other.ubound is None else self.ubound + other.ubound,
                        self.scale + other.scale)

    def __radd__(self, other):
        return self + other

    @staticmethod
    def _mul_bounds(alo, ahi, blo, bhi):
        """Interval product [alo, ahi] * [blo, bhi], where None means unbounded.

        Returns (lo, hi) with None on an unbounded side. Missing bounds are
        treated as -inf/+inf rather than skipped. Skipping them would give
        e.g. [1, None] * [2, 2] the wrong bounds [2, 2].
        """
        inf = float("inf")
        aends = (-inf if alo is None else alo, inf if ahi is None else ahi)
        bends = (-inf if blo is None else blo, inf if bhi is None else bhi)
        prods = []
        for a in aends:
            for b in bends:
                if (a == 0 and abs(b) == inf) or (b == 0 and abs(a) == inf):
                    # Interval-arithmetic convention: 0 * (+-inf) contributes 0, not nan.
                    prods.append(0.0)
                else:
                    prods.append(a * b)
        lo = min(prods)
        hi = max(prods)
        return (None if lo == -inf else lo, None if hi == inf else hi)

    def __mul__(self, other):
        other = ConcReal.convert(other)
        lbound, ubound = ConcReal._mul_bounds(self.lbound, self.ubound, other.lbound, other.ubound)
        return ConcReal(self.x * other.x, lbound, ubound, self.scale * other.scale)

    def __rmul__(self, other):
        return self * other

    def __sub__(self, other):
        return self + other * -1

    def __rsub__(self, other):
        return other + self * -1

    def __neg__(self):
        return self * -1

    def __truediv__(self, other):
        return self * (1 / other)

    def __rtruediv__(self, other):
        """other / self. Bounds are propagated for positive numeric other."""
        if (isinstance(other, int) or isinstance(other, float)) and other > 0:
            other = float(other)
            lbound = None
            ubound = None
            if self.lbound is not None and self.lbound > 0:
                ubound = other / self.lbound
                if self.ubound is not None:
                    lbound = other / self.ubound
            if self.ubound is not None and self.ubound < 0:
                lbound = other / self.ubound
                if self.lbound is not None:
                    ubound = other / self.lbound
            return ConcReal(other / self.x, lbound, ubound, other / self.scale)
        return other * (1 / self)

    def clamp(self):
        """Project a variable into [lbound, ubound] (rounding if isint); re-randomize on nan."""
        if not self.isvar:
            return
        vt = float(self.v.detach().numpy())
        if numpy.isnan(vt):
            self.randomize()
            return
        if self.lbound is not None and vt < self.lbound:
            vt = self.lbound
        if self.ubound is not None and vt > self.ubound:
            vt = self.ubound
        if self.isint:
            vt = round(vt)
        with torch.no_grad():
            self.v.copy_(torch.tensor(vt, dtype=torch.float64))

    def randomize(self):
        """Draw a random value: uniform on [lbound, ubound] if both bounds
        are set, otherwise an exponential(scale) offset from the available
        bound (or a random sign when unbounded on both sides)."""
        rnd = PsiOpts.get_random()
        if self.lbound is None or self.ubound is None:
            self.x = rnd.exponential() * self.scale
            if self.ubound is not None:
                self.x = self.ubound - self.x
            elif self.lbound is not None:
                self.x = self.lbound + self.x
            else:
                if rnd.uniform() < 0.5:
                    self.x *= -1
        else:
            self.x = rnd.uniform(self.lbound, self.ubound)
        self.copy_torch()

    def fraction_snap(self, denom = None, eps = None):
        self.x = iutil.float_snap(float(self.x), denom = denom, eps = eps)
        self.copy_torch()

    def hop(self, prob):
        """Re-randomize with probability prob."""
        rnd = PsiOpts.get_random()
        if rnd.uniform() < prob:
            self.randomize()

    def get_x(self):
        return self.x

    def get_v(self):
        return self.v

    def __float__(self):
        return float(self.x)

    def __int__(self):
        return int(round(float(self.x)))

    def torch(self):
        """Convert to torch.Tensor."""
        if iutil.istorch(self.x):
            return self.x
        return torch.tensor(self.x, dtype=torch.float64)

    @staticmethod
    def unbox(x):
        if isinstance(x, ConcReal):
            return x.x
        return x

    def __str__(self):
        return str(self.x)

    def __repr__(self):
        r = ""
        r += "ConcReal("
        r += str(self.x)
        if self.lbound is not None:
            r += ", lbound=" + str(self.lbound)
        if self.ubound is not None:
            r += ", ubound=" + str(self.ubound)
        r += ")"
        return r

    @latex_postprocess
    def _latex_(self):
        return iutil.float_tostr(float(self.x), style = PsiOpts.STR_STYLE_LATEX, force_float = self.force_float)

    def tostring(self, style = 0):
        """Convert to string.

        Parameters:
            style : Style of string conversion
                    STR_STYLE_STANDARD : I(X,Y;Z|W)
                    STR_STYLE_PSITIP : I(X+Y&Z|W)
        """
        style = iutil.convert_str_style(style)
        if style & PsiOpts.STR_STYLE_LATEX:
            return self._latex_()
        return str(self)
class RealRV(IBaseObj):
    """Discrete real-valued random variable.

    Stores the underlying component ``comp`` (readable as ``x``), an
    optional ``fcn`` and an optional support ``supp``. How fcn and supp are
    interpreted is not visible here.
    """

    def __init__(self, x, fcn = None, supp = None):
        self.comp, self.fcn, self.supp = x, fcn, supp

    @property
    def x(self):
        """The wrapped component (alias of ``comp``)."""
        return self.comp
class SharedRVModel(IBaseObj):
    """Model that assigns a value to each atom of the information diagram.

    ``rvvals[xmask]`` is the value of the atom indexed by the nonempty
    bitmask xmask over the random variables in ``index_rv``. The entropy of
    a set of variables is the sum of the atoms that intersect it (see
    get_H_mask). Real variables are stored in ``index_real``/``realvals``.
    """

    def __init__(self, reg = None, addrv = None):
        self.index_rv = IVarIndex()
        self.rvvals = []
        self.index_real = IVarIndex()
        self.realvals = []
        if reg is not None:
            self.init_reg(reg, addrv)

    def copy(self):
        r = SharedRVModel()
        r.index_rv = iutil.copy(self.index_rv)
        r.rvvals = iutil.copy(self.rvvals)
        r.index_real = iutil.copy(self.index_real)
        r.realvals = iutil.copy(self.realvals)
        return r

    def init_reg(self, reg, addrv = None):
        """Initialize atoms from region reg (plus extra variables addrv).

        Atoms forced to zero by the conditional-independence / functional
        dependency terms of reg are set to 0. All remaining atoms share a
        total mass of 1 equally.
        """
        reg.record_to(self.index_rv)
        if addrv is not None:
            addrv.record_to(self.index_rv)
        n = len(self.index_rv.comprv)
        self.rvvals = [0.0] * (1 << n)
        ics = reg.get_ic(include_ic = True, include_hc = True)
        for xmask in range(1, 1 << n):
            for a, c in ics.terms:
                # Atom xmask lies inside every a.x[i] and outside a.z.
                if all(self.index_rv.get_mask(x) & xmask for x in a.x) and not self.index_rv.get_mask(a.z) & xmask:
                    self.rvvals[xmask] = None
                    break
        maskcount = sum(int(self.rvvals[xmask] is not None) for xmask in range(1, 1 << n))
        for xmask in range(1, 1 << n):
            if self.rvvals[xmask] is None:
                self.rvvals[xmask] = 0.0
            else:
                self.rvvals[xmask] = 1.0 / maskcount

    def set_real(self, x, v):
        if isinstance(x, Expr):
            x = x.allcomp()
        x.record_to(self.index_real)
        i = self.index_real.get_index(x)
        while len(self.realvals) < len(self.index_real.compreal):
            self.realvals.append(0.0)
        self.realvals[i] = v

    def get_real(self, x):
        if isinstance(x, Expr):
            x = x.allcomp()
        i = self.index_real.get_index(x)
        if i < 0:
            return None
        return self.realvals[i]

    def get_H_mask(self, mask):
        """Joint entropy of the variables in mask: sum of the atoms intersecting it."""
        n = len(self.index_rv.comprv)
        r = 0.0
        for xmask in range(1, 1 << n):
            if mask & xmask != 0:
                r += self.rvvals[xmask]
        return r

    def get_val(self, expr):
        """Evaluate expr, or return None if it uses a real/region value that is not set.

        Information terms are expanded by inclusion-exclusion over the
        joint entropies of their arguments together with the conditioning.
        """
        r = 0.0
        for (a, c) in expr.terms:
            termType = a.get_type()
            if termType == TermType.IC:
                k = len(a.x)
                for t in range(1 << k):
                    csgn = -1
                    mask = self.index_rv.get_mask(a.z)
                    for i in range(k):
                        if (t & (1 << i)) != 0:
                            csgn = -csgn
                            mask |= self.index_rv.get_mask(a.x[i])
                    if mask != 0:
                        r += self.get_H_mask(mask) * csgn * c
            elif termType == TermType.REAL or termType == TermType.REGION:
                if a.isone():
                    r += c
                else:
                    t = self.get_real(a.x[0])
                    if t is None:
                        # Unknown value: report "cannot evaluate" like ConcModel.get_val,
                        # instead of failing with TypeError on None * c.
                        return None
                    r += t * c
        return r

    def __getitem__(self, x):
        if isinstance(x, Expr):
            return self.get_val(x)
        if isinstance(x, Region):
            return x.evalcheck(self)
        return None

    def __setitem__(self, key, value):
        if isinstance(key, Expr):
            self.set_real(key.allcomp(), value)

    def __call__(self, x):
        return self[x]

    def table(self, *args, **kwargs):
        """Plot the information diagram as a Karnaugh map.
        """
        return Region.universe().table(*args, self, **kwargs)

    def venn(self, *args, **kwargs):
        """Plot the information diagram as a Venn diagram.
        Can handle up to 5 random variables (uses Branko Grunbaum's Venn diagram for n=5).
        """
        return Region.universe().venn(*args, self, **kwargs)
class ConcModel(IBaseObj):
"""Concrete distributions of random variables and values of real variables."""
def __init__(self, bnet = None, istorch = None):
    """Create a model over the Bayesian network bnet (a new BayesNet if None).

    istorch selects torch tensors for computed quantities. It defaults to
    PsiOpts.settings["istorch"]; a RuntimeError is raised if it is requested
    without pytorch installed.
    """
    self.istorch = PsiOpts.settings["istorch"] if istorch is None else istorch
    if self.istorch and torch is None:
        raise RuntimeError("Requires pytorch. Please install it first.")
    self.bnet = BayesNet() if bnet is None else bnet
    nrv = len(self.bnet.index.comprv)
    # User-set distributions, keyed by (input index tuple, output index tuple).
    self.psmap = {}
    # Derived conditional distributions and joint marginals (cleared by clear_cache).
    self.psmap_cache = {}
    self.psmap_mask = {}
    # psv[k]: key in psmap of the distribution defining variable k.
    self.psv = [None] * nrv
    self.card = [None] * nrv
    self.pt = None
    self.index_real = IVarIndex()
    self.realvars = []
    self.opt_reg = None
def copy_shallow(self):
    """Copy with new containers whose elements (distributions, real values) are shared."""
    clone = ConcModel()
    clone.istorch = self.istorch
    clone.bnet = self.bnet.copy()
    clone.psmap = self.psmap.copy()
    clone.psmap_cache = self.psmap_cache.copy()
    clone.psv = self.psv[:]
    clone.psmap_mask = self.psmap_mask.copy()
    clone.card = self.card[:]
    clone.index_real = self.index_real.copy()
    clone.realvars = self.realvars[:]
    return clone
def copy(self):
    """Copy that also duplicates the stored contents via iutil.copy."""
    clone = ConcModel()
    clone.istorch = self.istorch
    clone.bnet = self.bnet.copy()
    for attr in ("psmap", "psmap_cache", "psv", "psmap_mask", "card"):
        setattr(clone, attr, iutil.copy(getattr(self, attr)))
    clone.index_real = self.index_real.copy()
    clone.realvars = iutil.copy(self.realvars)
    return clone
def set_real(self, x, v):
    """Set the value of real variable x (an Expr or a component).

    v may be a value (number, ConcReal, tensor, ...) or one of the strings
    "var" (new trainable ConcReal) and "var,rand" (new trainable ConcReal
    with a random initial value). Cached derived quantities are cleared.
    """
    self.clear_cache()
    # Only compare with the option strings when v is a string: comparing a
    # numpy array or torch tensor with a str is elementwise, warns, or makes
    # the truth value ambiguous.
    if isinstance(v, str):
        if v == "var":
            self.set_real(x, ConcReal(isvar = True))
            return
        if v == "var,rand":
            self.set_real(x, ConcReal(isvar = True, randomize = True))
            return
    if isinstance(x, Expr):
        x = x.allcomp()
    x.record_to(self.index_real)
    i = self.index_real.get_index(x)
    while len(self.realvars) < len(self.index_real.compreal):
        self.realvars.append(None)
    self.realvars[i] = v
def get_real(self, x):
    """Value stored for real variable x (Expr or component), or None if unknown."""
    comp = x.allcomp() if isinstance(x, Expr) else x
    pos = self.index_real.get_index(comp)
    return None if pos < 0 else self.realvars[pos]
def clear_cache(self):
    """Drop derived data (joint-marginal and conditional caches, pt).

    User-set distributions in psmap are kept.
    """
    self.psmap_cache = {}
    self.psmap_mask = {}
    self.pt = None
def comp_to_tuple(self, x):
    """Indices (in the network) of the components of x, or None if any is absent."""
    result = ()
    for comp in x:
        pos = self.bnet.index.get_index(comp)
        if pos < 0:
            return None
        result += (pos,)
    return result
def comp_to_pair(self, x):
    """Convert x to a (conditioning indices, target indices) pair.

    A Term gives (indices of x.z, concatenated indices of the parts of x.x).
    A Comp gives ((), indices of x). Anything else is treated as a
    (given, target) pair.
    """
    if isinstance(x, Term):
        given = self.comp_to_tuple(x.z)
        target = tuple()
        for part in x.x:
            target = target + self.comp_to_tuple(part)
        return (given, target)
    if isinstance(x, Comp):
        return (tuple(), self.comp_to_tuple(x))
    return (self.comp_to_tuple(x[0]), self.comp_to_tuple(x[1]))
def tuple_to_comp(self, y):
    """Comp made of the random variables at network indices y, in order."""
    comprv = self.bnet.index.comprv
    result = Comp.empty()
    for pos in y:
        result += comprv[pos]
    return result
def comp_get_sublens(self, x):
    """For a Term, lengths of all parts of x.x except the last; [] otherwise."""
    if not isinstance(x, Term):
        return []
    lens = [len(part) for part in x.x]
    del lens[-1]
    return lens
def set_prob(self, x, p):
"""Set the (conditional) distribution of the random variables in x.

x is typically a Term such as Y | X (or a Comp, or a (given, target)
pair). It is converted to index tuples (cin, cout) with comp_to_pair.

p may be:
- a str of comma-separated tokens: "var", "rand", "fcn" set the isvar /
  randomize / isfcn flags; any other token is the mode (last one wins).
  Modes: "flat", "equal", "add", "present"/"c", "past"/"p",
  "future"/"f"; any other mode (including none) creates a generic
  ConcDist of the right shape with the given flags;
- a callable (not a ConcDist/list/ExprArray): a deterministic function,
  converted with ConcDist.det_fcn;
- a list (converted to ExprArray if it contains an Expr, otherwise to a
  numpy array), another table with a .shape, or a ConcDist.

Raises ValueError if the table's number of dimensions or a dimension
length disagrees with the variables or their recorded cardinalities.
"""
if isinstance(p, str):
opt_split = p.split(",")
# NOTE(review): opt is never used.
opt = None
randomize = False
isvar = False
isfcn = False
mode = ""
for copt in opt_split:
if copt == "var":
isvar = True
elif copt == "rand":
randomize = True
elif copt == "fcn":
isfcn = True
else:
mode = copt
mode = mode.lower()
if mode == "flat":
shape_in, shape_out = self.convert_shape_pair(x)
self.set_prob(x, ConcDist.flat(shape_in, isvar = isvar))
elif mode == "equal":
shape_in, shape_out = self.convert_shape_pair(x)
self.set_prob(x, ConcDist.equal(shape_in, isvar = isvar))
elif mode == "add":
shape_in, shape_out = self.convert_shape_pair(x)
self.set_prob(x, ConcDist.add(shape_in, isvar = isvar))
elif mode == "present" or mode == "c":
# The first input q selects which of the remaining inputs is copied to the output.
shape_in, shape_out = self.convert_shape_pair(x)
shape_in_max = max(shape_in[1:])
dist = ConcDist.det_fcn((lambda *a: a[a[0] + 1]), (shape_in, shape_in_max))
self.set_prob(x, dist)
elif mode == "past" or mode == "p":
# Output encodes (q, first q series values) as one index:
# a mixed-radix number offset by shifts[q], the start of block q.
shape_in, shape_out = self.convert_shape_pair(x)
shape_array = shape_in[1:]
n = len(shape_array)
sizes = [iutil.product(shape_array[:i]) for i in range(n)]
shifts = [sum(sizes[:i]) for i in range(n + 1)]
def toindex(q, *xs):
t = 0
for a, b in zip(xs[:q], shape_array[:q]):
t = t * b + a
return t + shifts[q]
dist = ConcDist.det_fcn(toindex, (shape_in, shifts[n]))
self.set_prob(x, dist)
elif mode == "future" or mode == "f":
# Same as "past", but encodes the series values after position q.
shape_in, shape_out = self.convert_shape_pair(x)
shape_array = shape_in[1:]
n = len(shape_array)
sizes = [iutil.product(shape_array[i + 1 :]) for i in range(n)]
shifts = [sum(sizes[:i]) for i in range(n + 1)]
def toindex(q, *xs):
t = 0
for a, b in zip(xs[q + 1:], shape_array[q + 1:]):
t = t * b + a
return t + shifts[q]
dist = ConcDist.det_fcn(toindex, (shape_in, shifts[n]))
self.set_prob(x, dist)
else:
shape_in, shape_out = self.convert_shape_pair(x)
self.set_prob(x, ConcDist(shape_in = shape_in, shape_out = shape_out,
isvar = isvar, randomize = randomize, isfcn = isfcn))
return
if callable(p) and not isinstance(p, (ConcDist, list, ExprArray)):
dist = ConcDist.det_fcn(p, self.convert_shape_pair(x))
self.set_prob(x, dist)
return
# x must be registered before comp_to_pair can resolve its indices.
# NOTE: a failed shape check below still leaves x registered in bnet.
self.bnet += x
cin, cout = self.comp_to_pair(x)
if isinstance(p, list):
if iutil.hasinstance(p, Expr):
p = ExprArray(p)
else:
p = numpy.array(p)
shape = p.shape
if len(shape) != len(cin) + len(cout):
raise ValueError("Number of dimensions of prob. table = " + str(len(shape))
+ " does not match number of variables = " + str(len(cin) + len(cout)) + ".")
return
for j in range(len(cin + cout)):
t = self.get_card_id((cin + cout)[j])
if t is not None and t != shape[j]:
raise ValueError("Length of dimension " + str(self.bnet.index.comprv[(cin + cout)[j]]) + " of prob. table = " + str(shape[j])
+ " does not match its cardinality = " + str(t) + ".")
return
while len(self.card) < len(self.bnet.index.comprv):
self.card.append(None)
while len(self.psv) < len(self.bnet.index.comprv):
self.psv.append(None)
for j in range(len(cin + cout)):
self.card[(cin + cout)[j]] = shape[j]
if not isinstance(p, ConcDist):
p = ConcDist(p, num_in = len(cin), check_valid = True)
self.psmap[(cin, cout)] = p
# Each output variable records which table defines it (used by get_prob_mask).
for k in cout:
self.psv[k] = (cin, cout)
self.clear_cache()
def set_series(self, x, q, y):
    """Set x[0], x[1], x[2] (those present) to the "present", "past" and
    "future" selector channels, each conditioned on q + y.

    A CompArray y is first summed into one component.
    """
    if isinstance(y, CompArray):
        y = y.get_sum()
    modes = ("present", "past", "future")
    for i in range(min(len(x), len(modes))):
        self[x[i] | q + y] = modes[i]
def calc_dist(self, p):
    """Refresh a symbolic distribution's table: p.p = self[p.expr].

    Does nothing when p is None or p.expr is None.
    """
    if p is not None and p.expr is not None:
        p.p = self[p.expr]
def get_prob_mask(self, mask):
"""Joint distribution of the variables whose indices are the set bits of mask.

Output axes are in increasing variable-index order. Results are cached
in psmap_mask. Let k be the highest variable in mask. The table that
defines k (psv[k] = (tin, tout)) is combined, by semidirect product, with
the recursively computed joint of (mask | inputs) minus the table's
outputs. The result is then marginalized back to mask.
Raises ValueError if variable k has no distribution set.
"""
t = self.psmap_mask.get(mask, None)
self.calc_dist(t)
if t is not None:
return t
n = len(self.bnet.index.comprv)
# k = index of the highest set bit of mask.
k = 0
while (1 << (k + 1)) <= mask:
k += 1
if self.psv[k] is None:
raise ValueError("Random variable " + str(self.bnet.index.comprv[k]) + " has unspecified distribution.")
return
tin, tout = self.psv[k]
tp = self.psmap[(tin, tout)]
self.calc_dist(tp)
tin_mask = 0
for a in tin:
tin_mask |= 1 << a
tout_mask = 0
for a in tout:
tout_mask |= 1 << a
# Variables still needed: those requested plus the table's inputs, minus its outputs.
mask1 = (mask | tin_mask) & ~tout_mask
p2 = None
if mask1 > 0:
p1 = self.get_prob_mask(mask1)
# Positions of the table's inputs among the (index-ordered) axes of p1.
p2 = p1.semidirect(tp, [iutil.bitcount(mask1 & ((1 << i) - 1)) for i in tin])
else:
p2 = tp
# idinv[i] = axis of variable i in p2: mask1 variables first, then tout.
idinv = [None] * n
ci = 0
for i in range(n):
if mask1 & (1 << i):
idinv[i] = ci
ci += 1
for i in tout:
idinv[i] = ci
ci += 1
p3 = p2.marginal(*[idinv[i] for i in range(n) if mask & (1 << i)])
self.psmap_mask[mask] = p3
return p3
def get_prob_pair(self, cin, cout):
    """Conditional distribution P(cout | cin) for variable-index tuples cin, cout.

    Lookup order: user-set distributions (psmap), then cached derived ones
    (psmap_cache). Otherwise the joint over cin + cout is divided by the
    marginal over cin (or used as is when cin is empty), and the result is
    cached in psmap_cache.
    """
    for store in (self.psmap, self.psmap_cache):
        t = store.get((cin, cout), None)
        self.calc_dist(t)
        if t is not None:
            return t
    cin_mask = 0
    for a in cin:
        cin_mask |= 1 << a
    cout_mask = 0
    for a in cout:
        cout_mask |= 1 << a
    joint_mask = cin_mask | cout_mask
    # get_prob_mask orders axes by variable index; permute them to cin + cout.
    p1 = self.get_prob_mask(joint_mask)
    p1 = p1.reorder([iutil.bitcount(joint_mask & ((1 << i) - 1)) for i in cin + cout])
    if cin_mask != 0:
        p0 = self.get_prob_mask(cin_mask)
        p0 = p0.reorder([iutil.bitcount(cin_mask & ((1 << i) - 1)) for i in cin])
        r = p1 / p0
    else:
        r = p1
    self.psmap_cache[(cin, cout)] = r
    return r
def get_prob(self, x):
    """(Conditional) distribution of x (Term, Comp or (given, target) pair).

    Raises ValueError if some random variable of x is not in the model.
    For a ConcDist result, ``sublens`` is set from comp_get_sublens(x).
    """
    cin, cout = self.comp_to_pair(x)
    missing = cin is None or cout is None
    if missing:
        raise ValueError("Some random variables are absent in the model.")
    dist = self.get_prob_pair(cin, cout)
    if isinstance(dist, ConcDist):
        dist.sublens = self.comp_get_sublens(x)
    return dist
def get_card_id(self, i):
    """Cardinality of the variable with index i, or None if not recorded."""
    return self.card[i] if i < len(self.card) else None
def get_card(self, x):
    """Cardinality of random variable x, or None if x is not in the model."""
    pos = self.bnet.index.get_index(x)
    return None if pos < 0 else self.get_card_id(pos)
def get_card_default(self, x):
    """Cardinality of x from the model, falling back to x.get_card() if x is not in it."""
    pos = self.bnet.index.get_index(x)
    if pos >= 0:
        return self.get_card_id(pos)
    return x.get_card()
def convert_shape(self, x):
    """Shape tuple for x.

    None gives (), an int gives (x,), a Comp gives the cardinalities of its
    variables, a ConcDist gives its shape_out, and anything else is
    converted with tuple().
    """
    if x is None:
        return ()
    if isinstance(x, int):
        return (x,)
    if isinstance(x, Comp):
        return tuple(map(self.get_card_default, x))
    if isinstance(x, ConcDist):
        return x.shape_out
    return tuple(x)
def convert_shape_pair(self, x):
    """(input shape, output shape) pair for x.

    None gives ((), ()), an int gives ((), (x,)). A Term gives the shapes of
    its conditioning part and of its concatenated target parts. A Comp or
    ConcDist gives ((), its shape). A tuple of ints (or empty) is an output
    shape. Otherwise x is a (given, target) pair.
    """
    if x is None:
        return ((), ())
    if isinstance(x, int):
        return ((), (x,))
    if isinstance(x, Term):
        given = self.convert_shape(x.z)
        target = ()
        for part in x.x:
            target = target + self.convert_shape(part)
        return (given, target)
    if isinstance(x, (Comp, ConcDist)):
        return ((), self.convert_shape(x))
    if isinstance(x, tuple) and (not x or isinstance(x[0], int)):
        return ((), tuple(x))
    return (self.convert_shape(x[0]), self.convert_shape(x[1]))
def get_H(self, x):
    """Entropy of x under the model: the ``x`` field of get_prob(x).entropy()."""
    ent = self.get_prob(x).entropy()
    return ent.x
def get_ent_vector(self, x):
    """Entropies of all 2**len(x) subsets of x, indexed by bitmask (mask 0 included)."""
    return [self.get_H(x.from_mask(mask)) for mask in range(1 << len(x))]
def discover(self, x, eps = None, skip_simplify = False):
    """Discover conditional independence among variables in x.

    The entropy vector of x is computed from the model and passed to
    Region.ent_vector_discover_ic.
    """
    entropies = self.get_ent_vector(x)
    return Region.ent_vector_discover_ic(entropies, x, eps, skip_simplify)
def get_region(self, vals = False):
    """Region given by the model's Bayesian network.

    With vals=True it also pins the joint entropy of every nonempty subset
    of random variables, and the value of every real variable, to its
    current value in the model.
    """
    r = self.bnet.get_region()
    if not vals:
        return r
    rvs = self.bnet.index.comprv.copy()
    for mask in range(1, 1 << len(rvs)):
        sub = rvs.from_mask(mask)
        r.iand_norename(Expr.H(sub) == float(self.get_H(sub)))
    for y in self.allcomprealvar_exprlist():
        r.iand_norename(y == float(self[y]))
    return r
def allcomprv(self):
    """Copy of all random variables in the model's Bayesian network."""
    rvs = self.bnet.index.comprv
    return rvs.copy()
def allcompreal(self):
    """Copy of all real variables recorded in the model."""
    reals = self.index_real.compreal
    return reals.copy()
def allcomprealvar(self):
    """Copy of all real variables recorded in the model (same as allcompreal)."""
    reals = self.index_real.compreal
    return reals.copy()
def allcomp(self):
    """All random variables followed by all real variables."""
    rvs = self.allcomprv()
    reals = self.allcompreal()
    return rvs + reals
def add_reg(self, reg, other_neighbors = None):
"""Add trainable unknowns for region reg to this model (in place).

Random variables of reg that are not yet in the model get trainable,
randomized ConcDist tables. Their parents come from a Bayesian network
derived from reg, or, if a variable is not in that network, from all
previously seen variables. Variables with the same parents (or whose
parents are an existing group plus that group's members) and the same
functional flag share one table. Real variables of reg without a value
get trainable ConcReals. Unknown auxiliary cardinalities default to
PsiOpts.settings["opt_aux_card"].

Returns (varlist, cons): the new trainable objects, and the inequalities
of reg not already implied by the model region plus detected functional
dependencies. Raises ValueError if reg is a RegionOp that cannot be
converted to a simple region.
"""
varlist = []
cons = Region.universe()
reg = reg.copy()
regallcomprv = reg.allcomprv()
if isinstance(reg, RegionOp):
reg = reg.tosimple()
if reg is None:
raise ValueError("User-defined information quantities with RegionOp constraints cannot be optimized.")
return None
card0 = PsiOpts.settings["opt_aux_card"]
reg.aux_strengthen(self.bnet.index.comprv, other_neighbors = other_neighbors)
reg.simplify_quick(zero_group = 0)
# Close reg under semigraphoid rules so a Bayesian network can be read off it.
regcom = reg.copy()
regcom.iand_norename(regcom.completed_semigraphoid(max_iter = 10000))
tbnet = regcom.get_bayesnet(roots = self.bnet.index.comprv.inter(regcom.allcomprv()),
skip_simplify = True)
fcnreg = Region.universe()
# clus: list of [parents, variables, isfcn] groups that share one table.
clus = []
rv_all = self.bnet.index.comprv.copy()
for a in tbnet.index.comprv + regallcomprv:
if self.bnet.index.get_index(a) >= 0:
continue
pa = None
if tbnet.index.get_index(a) >= 0:
pa = tbnet.get_parents(a)
else:
pa = rv_all.copy()
rv_all += a
# a is functional if reg (without auxiliaries) implies H(a | parents) <= 0.
isfcn = reg.copy_noaux().implies(Expr.Hc(a, pa) <= 0)
if isfcn:
fcnreg.iand_norename(Expr.Hc(a, pa) <= 0)
for i in range(len(clus)):
cpa, ca, cisfcn = clus[i]
if pa == cpa + ca and cisfcn == isfcn:
clus[i][1] += a
break
else:
clus.append([pa, a, isfcn])
for cpa, ca, cisfcn in clus:
ccards = []
for a in ca:
ccard = a.get_card()
if ccard is None:
ccard = card0
ccards.append(ccard)
p = ConcDist(shape = (tuple(self.get_card(t) for t in cpa), tuple(ccards)),
isvar = True, randomize = True, isfcn = cisfcn)
self[ca | cpa] = p
varlist.append(p)
for a in reg.allcomprealvar_exprlist():
if self.get_real(a) is None:
t = ConcReal(isvar = True)
self[a] = t
varlist.append(t)
# Keep only the inequalities of reg not already implied by the model
# region, the functional dependencies and the constraints kept so far.
csreg = self.get_region() & cons & fcnreg
for ineq in reg:
if not csreg.implies(ineq):
cons.iand_norename(ineq)
csreg.iand_norename(ineq)
return (varlist, cons)
def get_val_regterm(self, term, esgn):
"""Evaluate a user-defined (region) term under this model.

If term.fcncall is set, its arguments are evaluated against the model
(Term/Comp -> distribution, Expr -> value, Region -> 1/0, "model" ->
self) and term.get_fcneval is called. Otherwise term.get_reg_sgn_bds()
gives (reg, sgn, bds). If reg and the bounds only involve random
variables already in the model, the result is the min (sgn = 1) or max
(sgn = -1) of the bounds. Otherwise unknowns for reg are added to a
shallow copy of the model and an optimization is run; its region is
stored in self.opt_reg.
Returns None when the term has no region/bounds.
"""
if term.fcncall is not None:
cargs = []
for a in term.fcnargs:
if isinstance(a, Term) or isinstance(a, Comp):
cargs.append(self.get_prob(a))
elif isinstance(a, Expr):
cargs.append(self.get_val(a, 0))
elif isinstance(a, Region):
cargs.append(1 if self[a] else 0)
elif a == "model":
cargs.append(self)
else:
cargs.append(a)
return term.get_fcneval(cargs)
t = term.get_reg_sgn_bds()
if t is None:
return None
reg, sgn, bds = t
if len(bds) == 0:
return None
if isinstance(reg, RegionOp):
reg = reg.tosimple()
if reg is None:
raise ValueError("User-defined information quantities with RegionOp constraints cannot be optimized.")
return None
rcomp = reg.allcomprv()
for b in bds:
rcomp += b.allcomprv_shallow()
if self.bnet.index.comprv.super_of(rcomp):
# Every variable is known: take the tightest bound directly.
r = numpy.inf
for b in bds:
t = self.get_val(b, esgn * sgn) * sgn
if float(t) < float(r):
r = t
return r * sgn
# NOTE(review): card0 is never used here.
card0 = PsiOpts.settings["opt_aux_card"]
# NOTE(review): this simplifies reg in place; confirm get_reg_sgn_bds returns a copy
# and not the term's own region.
reg.simplify_quick(zero_group = 0)
cs = self.copy_shallow()
varlist, cons = cs.add_reg(reg)
tvar = None
if len(bds) == 1:
tvar = bds[0]
else:
# Several bounds: optimize a fresh variable constrained by all of them.
tvar = Expr.real("#TVAR_" + term.get_name())
tvar_r = ConcReal(isvar = True, randomize = True)
cs[tvar] = tvar_r
varlist.append(tvar_r)
for b in bds:
cons &= tvar * sgn <= b * sgn
retval = cs.optimize(tvar, varlist, cons, sgn = sgn)
self.opt_reg = cs.opt_reg
return retval
def get_val(self, expr, esgn = 0):
    """Evaluate expr under this model.

    Information terms (TermType.IC) are expanded by inclusion-exclusion
    into joint entropies of their arguments together with the conditioning
    part. Real and region terms use values stored with set_real. A region
    term without a stored value is evaluated by get_val_regterm, passing
    esgn times the sign of its coefficient.

    Returns
    -------
    ConcReal or None
        None if some real/region term cannot be evaluated.
    """
    total = 0.0
    for term, coeff in expr.terms:
        ttype = term.get_type()
        if ttype == TermType.IC:
            index = self.bnet.index
            nparts = len(term.x)
            zmask = index.get_mask(term.z)
            for subset in range(1 << nparts):
                sign = -1
                mask = zmask
                for i in range(nparts):
                    if subset & (1 << i):
                        sign = -sign
                        mask |= index.get_mask(term.x[i])
                if mask:
                    total += self.get_H(index.from_mask(mask)) * sign * coeff
        elif ttype == TermType.REAL or ttype == TermType.REGION:
            if term.isone():
                total += coeff
                continue
            val = self.get_real(term.x[0])
            if val is None and ttype == TermType.REGION:
                val = self.get_val_regterm(term, esgn * (1 if coeff > 0 else -1))
            if val is None:
                return None
            if isinstance(val, ConcReal):
                val = val.x
            total += val * coeff
    return ConcReal(total)
def check_region(self, x, reqgt0 = None, reqgt0_reg = None):
    """Check region ``x`` against this model, optimizing over free variables.

    - If ``x`` has no auxiliary random variables (after quick simplification)
      and ``reqgt0`` is None, this is simply ``x.evalcheck(self)``.
    - If ``x`` is a RegionOp or contains implications, ``~x`` is split into
      (cause, consequence, auxiliary) triples and each cause is checked
      recursively. True is returned as soon as one recursive check succeeds.
    - Otherwise the variables of ``x`` are added to a shallow copy of the
      model, and ``optimize`` maximizes ``reqgt0`` (or the constant 0)
      subject to ``x``. The optimized copy is stored via ``self.opt_reg``.

    Parameters:
        x : Region to check.
        reqgt0 : Expr to maximize. When given, optimization may stop once it
            exceeds PsiOpts.settings["eps_violate_cutoff"].
        reqgt0_reg : Region; if it holds at the optimized model the result
            is False.

    Returns:
        bool: result of ``x.evalcheck`` on the (optimized) model, or False
        as described above.

    Raises:
        ValueError: if ``x`` has universally quantified random variables.
    """
    if x.aux_present() and reqgt0 is None:
        x = x.simplified_quick()
    if not x.aux_present() and reqgt0 is None:
        return x.evalcheck(self)

    if isinstance(x, RegionOp) or x.imp_present():
        if not x.auxi.isempty():
            raise ValueError("Universally quantified random variables cannot be optimized.")
        clist = (~x).to_cause_consequence()
        for imp, cons, auxi in clist:
            creqgt0 = None
            if not cons.isempty():
                # presumably a measure of how much ``cons`` is violated
                # (see Region.expr_sum_violate) -- maximized below
                creqgt0 = cons.expr_sum_violate()
            if self.check_region(imp.exists(auxi), reqgt0 = creqgt0, reqgt0_reg = cons):
                return True
        return False

    cs = self.copy_shallow()
    varlist, cons = cs.add_reg(x, other_neighbors = reqgt0_reg)
    optval_cutoff = -numpy.inf
    if reqgt0 is None:
        reqgt0 = Expr.zero()
    else:
        optval_cutoff = PsiOpts.settings["eps_violate_cutoff"]
    # Called for its side effect of moving ``cs`` to the optimum;
    # the optimal value itself is not needed here.
    cs.optimize(reqgt0, varlist, cons, sgn = 1, optval_cutoff = optval_cutoff)
    self.opt_reg = cs.opt_reg
    if reqgt0_reg is not None and reqgt0_reg.evalcheck(cs):
        return False
    return x.evalcheck(cs)
def __call__(self, x):
    """Evaluate ``x`` under this model.

    - Expr: returns ``self.get_val(x)``.
    - ExprArray: evaluates each entry via ``self[...]`` and reshapes to
      ``x.shape``. The result is a torch tensor when ``self.istorch``,
      otherwise a numpy array of floats.
    - Region: returns ``self.check_region(x)``.
    - Anything else: returns None.
    """
    if isinstance(x, Expr):
        return self.get_val(x)
    if isinstance(x, ExprArray):
        if not self.istorch:
            flat = numpy.array([float(self[e]) for e in x])
            return numpy.reshape(flat, x.shape)
        flat = torch.hstack(tuple(iutil.ensure_torch(self[e]) for e in x))
        return torch.reshape(flat, x.shape)
    if isinstance(x, Region):
        return self.check_region(x)
    return None
def __getitem__(self, x):
    """Look up a probability table or a real value in this model.

    - Tuples are first converted with Term.from_symbols.
    - Tuples, Comp and Term objects go to ``get_prob``.
    - A CompArray goes to ``get_prob`` on its term.
    - A single-term Expr with coefficient 1 that is a REAL/REGION term
      returns the stored value from ``get_real`` when one exists.
    - Everything else falls through to ``self(x)``.
    """
    if isinstance(x, tuple):
        x = Term.from_symbols(x)
    if isinstance(x, (tuple, Comp, Term)):
        return self.get_prob(x)
    if isinstance(x, CompArray):
        return self.get_prob(x.get_term())
    if isinstance(x, Expr) and len(x) == 1:
        term, coeff = x.terms[0]
        if coeff == 1 and term.get_type() in (TermType.REAL, TermType.REGION):
            stored = self.get_real(term.x[0])
            if stored is not None:
                return stored
    return self(x)
def __setitem__(self, key, value):
    """Assign ``value`` to ``key``.

    An Expr key sets a real variable (via ``key.allcomp()``). Any other key
    sets a probability table. Tuple keys are converted with
    Term.from_symbols first.
    """
    if isinstance(key, tuple):
        key = Term.from_symbols(key)
    if not isinstance(key, Expr):
        self.set_prob(key, value)
        return
    self.set_real(key.allcomp(), value)
def convert_torch_tensors(self, x):
    """Collect the torch tensors (and variable distributions/reals) behind ``x``.

    ``x`` may be a (nested) list, a torch tensor, a Comp (one entry per
    random variable found in the Bayesian-network index), an Expr or
    ExprArray (via get_real), or a ConcDist/ConcReal.

    Returns:
        (tensors, variables): the torch tensors found, and the ConcDist/ConcReal
        objects with ``isvar`` set.
    """
    r = []
    r_dist = []

    def absorb(item):
        # Recurse and merge the results into r / r_dist.
        tr, tdist = self.convert_torch_tensors(item)
        r.extend(tr)
        r_dist.extend(tdist)

    if isinstance(x, list):
        for a in x:
            absorb(a)
    elif torch is not None and isinstance(x, torch.Tensor):
        r.append(x)
    elif isinstance(x, Comp):
        for a in x:
            # Look up each component individually (previously the whole Comp
            # ``x`` was looked up on every iteration).
            i = self.bnet.index.get_index(a)
            if i >= 0:
                absorb(self.ps[i])
    elif isinstance(x, Expr):
        t = self.get_real(x)
        if t is not None:
            absorb(t)
    elif isinstance(x, ExprArray):
        for x2 in x:
            t = self.get_real(x2)
            if t is not None:
                absorb(t)
    elif isinstance(x, (ConcDist, ConcReal)):
        t = x.get_v()
        if t is not None and torch is not None and isinstance(t, torch.Tensor):
            r.append(t)
        if x.isvar:
            r_dist.append(x)
    return (r, r_dist)
@staticmethod
def tensor_copy_list(x, y):
    """Copy the values of tensor list ``y`` into tensor list ``x`` in place.

    Entries already present in ``x`` are overwritten with ``copy_``. Missing
    entries are appended as detached clones. Runs under ``torch.no_grad()``
    so the copies are not recorded by autograd.
    """
    with torch.no_grad():
        for i, src in enumerate(y):
            if i < len(x):
                x[i].copy_(src)
            else:
                x.append(src.detach().clone())
def get_tensor(self, x):
    """Evaluate ``x`` and unwrap ConcReal results.

    ``x`` is either an Expr (evaluated with get_val) or a callable that is
    given this model. A ConcReal result is replaced by its ``.x`` field.
    Other results are returned as-is.
    """
    value = self.get_val(x) if isinstance(x, Expr) else x(self)
    return value.x if isinstance(value, ConcReal) else value
@staticmethod
def tensor_list_to_array(varlist):
    """Flatten a list of tensors into a single 1-D numpy array.

    Tensors are detached and flattened in order, then concatenated. NaN
    entries are replaced by 0.0. The result is always a fresh array, so the
    tensors' memory is never written.
    """
    parts = [numpy.reshape(v.detach().numpy(), (-1,)) for v in varlist]
    if not parts:
        return numpy.array([])
    r = numpy.concatenate(parts)  # concatenate copies, so in-place edits are safe
    r[numpy.isnan(r)] = 0.0
    return r
@staticmethod
def tensor_list_grad_to_array(varlist):
    """Flatten the gradients of a list of tensors into a single 1-D numpy array.

    A tensor whose ``.grad`` is None contributes ``numel()`` zeros. NaN
    entries are replaced by 0.0. The result is always a fresh array.
    """
    parts = []
    for v in varlist:
        if v.grad is None:
            parts.append(numpy.zeros(v.numel()))
        else:
            parts.append(numpy.reshape(v.grad.numpy(), (-1,)))
    if not parts:
        return numpy.array([])
    r = numpy.concatenate(parts)  # copy, never aliases the gradient buffers
    r[numpy.isnan(r)] = 0.0
    return r
@staticmethod
def tensor_list_from_array(varlist, a):
    """Scatter the flat array ``a`` back into the tensors of ``varlist`` in place.

    Consecutive slices of ``a`` (one per tensor, sized by ``numel()``) are
    reshaped to each tensor's shape and copied in as float64 without
    recording autograd history. Any existing gradient is zeroed.
    """
    offset = 0
    for v in varlist:
        k = v.numel()
        with torch.no_grad():
            v.copy_(torch.tensor(numpy.reshape(a[offset:offset + k], v.shape), dtype=torch.float64))
        if v.grad is not None:
            v.grad.zero_()
        offset += k
@staticmethod
def tensor_list_get_bds(varlist, distlist, ismat):
    """Build variable bounds and simplex constraints for scipy.optimize.minimize.

    The flattened variable vector is the concatenation of the tensors in
    ``varlist``. For each tensor owned by a ConcDist in ``distlist``:
    - every entry gets the bound (0, 1);
    - if each row has ``stride = prod(shape_out) - 1 >= 2`` free entries,
      the sum of each row is constrained to be <= 1.
    A ConcReal with an lbound/ubound gets that bound on its (first) entry.

    Parameters:
        varlist : list of torch tensors.
        distlist : list of ConcDist/ConcReal objects (matched by identity of get_v()).
        ismat : if True, return the row-sum constraints as one
            scipy.optimize.LinearConstraint; otherwise as SLSQP-style
            {"type": "ineq", "fun", "jac"} dicts.

    Returns:
        (bds, constraints)
    """
    sizes = [v.numel() for v in varlist]
    xsize = sum(sizes)
    bds = [(-numpy.inf, numpy.inf)] * xsize
    cons = []
    num_cons = 0
    cons_A_data = []
    cons_A_r = []
    cons_A_c = []

    def get_fcn(i0, i1):
        # 1 - sum(x[i0:i1]) >= 0
        def fcn(x):
            return 1.0 - sum(x[i0: i1])
        return fcn

    def get_jac(i0, i1):
        # Gradient of get_fcn(i0, i1): -1 on [i0, i1), 0 elsewhere.
        def fcn(x):
            r = numpy.zeros(xsize)
            r[i0:i1] = -numpy.ones(i1 - i0)
            return r
        return fcn

    c = 0
    for v, k in zip(varlist, sizes):
        for d in distlist:
            if d.get_v() is not v:
                continue
            if isinstance(d, ConcDist):
                stride = functools.reduce(lambda x, y: x * y, d.shape_out, 1) - 1
                if stride >= 2:
                    for c2 in range(c, c + k, stride):
                        if ismat:
                            cons_A_data += [1.0] * stride
                            cons_A_r += [num_cons] * stride
                            cons_A_c += list(range(c2, c2 + stride))
                            num_cons += 1
                        else:
                            cons.append({
                                "type": "ineq",
                                "fun": get_fcn(c2, c2 + stride),
                                "jac": get_jac(c2, c2 + stride)
                            })
                for i in range(c, c + k):
                    bds[i] = (0.0, 1.0)
            elif isinstance(d, ConcReal):
                if d.lbound is not None or d.ubound is not None:
                    bds[c] = (-numpy.inf if d.lbound is None else d.lbound,
                              numpy.inf if d.ubound is None else d.ubound)
            break
        c += k

    if not ismat:
        return (bds, cons)
    if num_cons == 0:
        return (bds, [])
    mat = scipy.sparse.coo_matrix((cons_A_data, (cons_A_r, cons_A_c)), shape = (num_cons, xsize)).toarray()
    lcons = scipy.optimize.LinearConstraint(mat, -numpy.inf, 1.0)
    return (bds, [lcons])
def optimize(self, expr, vs, reg = None, sgn = 1, optimizer = None, learnrate = None, learnrate2 = None,
        momentum = None, num_iter = None, num_iter2 = None, num_points = None,
        num_hop = None, hop_temp = None, hop_prob = None,
        alm_rho = None, alm_rho_pow = None, alm_step = None, alm_penalty = None,
        eps_converge = None, eps_tol = None, optval_cutoff = None):
    """
    Minimize/maximize expr with variables in the list vs, constrained in the region reg.
    Uses a combination of SLSQP, gradient descent, Adam (or any optimizer with the pytorch interface)
    and basin-hopping. Constraints are handled using augmented Lagrangian method.
    Kraft, D. A software package for sequential quadratic programming. 1988. Tech. Rep. DFVLR-FB 88-28,
    DLR German Aerospace Center - Institute for Flight Mechanics, Koln, Germany.
    Kingma, Diederik P., and Jimmy Ba. "Adam: A method for stochastic optimization."
    arXiv preprint arXiv:1412.6980 (2014).
    Wales, David J.; Doye, Jonathan P. K. (1997). "Global Optimization by Basin-Hopping
    and the Lowest Energy Structures of Lennard-Jones Clusters Containing up to 110 Atoms".
    The Journal of Physical Chemistry A. 101 (28): 5111-5116.
    Hestenes, M. R. (1969). "Multiplier and gradient methods". Journal of Optimization
    Theory and Applications. 4 (5): 303-320.
    Parameters
    ----------
    expr : Expr
        The expression to be optimized. Can either be an Expr object, a number,
        or a callable accepting ConcModel as argument.
    vs : list
        The variables to be optimized over (list of ConcDist and/or ConcReal).
        The caller's list is not modified.
    reg : Region, optional
        The region where the variables are constrained. May also be an
        (expr, ">=") tuple, a list of Regions/tuples, or another object that is
        treated as (reg, ">="). The caller's list is not modified.
        The default is None.
    sgn : int, optional
        Set to 1 for maximization, -1 for minimization. The default is 1.
    optimizer : closure, optional
        The optimizer. Either "sgd", "adam", the name of a scipy.optimize.minimize
        method (e.g. "SLSQP", "trust-constr"), or any function that returns a pytorch
        optimizer given the list of variables as argument. The default is None.
    learnrate : float, optional
        Learning rate. The default is None.
    learnrate2 : float, optional
        Learning rate in phase 2. The default is None.
    momentum : float, optional
        Momentum. The default is None.
    num_iter : int, optional
        Number of iterations. The default is None.
    num_iter2 : int, optional
        Number of iterations in phase 2. The default is None.
    num_points : int, optional
        Number of random starting points. The default is None.
    num_hop : int, optional
        Number of hops for basin-hopping. The default is None.
    hop_temp : float, optional
        Temperature for basin-hopping. The default is None.
    hop_prob : float, optional
        Probability of hopping for basin-hopping. The default is None.
    alm_rho : float, optional
        Rho in augmented Lagrangian method. The default is None.
    alm_rho_pow : float, optional
        Multiply rho by this amount every alm_step iterations. The default is None.
    alm_step : int, optional
        Number of iterations between dual-variable updates (and rho increases)
        in the augmented Lagrangian method. The default is None.
    alm_penalty : float, optional
        Weight of the total constraint violation when scoring a candidate
        solution. The default is None.
    eps_converge : float, optional
        Convergence is declared if the change is smaller than eps_converge. The default is None.
    eps_tol : float, optional
        Tolerance for equality and inequality constraints. The default is None.
    optval_cutoff : float, optional
        Stop the algorithm when the objective goes beyond this value. The default is None.
    Returns
    -------
    r : float
        The best value of expr found (a maximum if sgn = 1, a minimum if sgn = -1).
        If no candidate satisfying all constraints was found, returns -inf when
        maximizing and inf when minimizing.
    """
    verbose = PsiOpts.settings.get("verbose_opt", False)
    verbose_step = PsiOpts.settings.get("verbose_opt_step", False)
    verbose_step_var = PsiOpts.settings.get("verbose_opt_step_var", False)
    if optimizer is None:
        optimizer = PsiOpts.settings["opt_optimizer"]
    if learnrate is None:
        learnrate = PsiOpts.settings["opt_learnrate"]
    if learnrate2 is None:
        learnrate2 = PsiOpts.settings["opt_learnrate2"]
    if momentum is None:
        momentum = PsiOpts.settings["opt_momentum"]
    if num_iter is None:
        num_iter = PsiOpts.settings["opt_num_iter"]
    if num_iter2 is None:
        num_iter2 = PsiOpts.settings["opt_num_iter2"]
    if num_points is None:
        num_points = PsiOpts.settings["opt_num_points"]
    if num_hop is None:
        num_hop = PsiOpts.settings["opt_num_hop"]
    if hop_temp is None:
        hop_temp = PsiOpts.settings["opt_hop_temp"]
    if hop_prob is None:
        hop_prob = PsiOpts.settings["opt_hop_prob"]
    if alm_rho is None:
        alm_rho = PsiOpts.settings["opt_alm_rho"]
    if alm_rho_pow is None:
        alm_rho_pow = PsiOpts.settings["opt_alm_rho_pow"]
    if alm_step is None:
        alm_step = PsiOpts.settings["opt_alm_step"]
    if alm_penalty is None:
        alm_penalty = PsiOpts.settings["opt_alm_penalty"]
    if eps_converge is None:
        eps_converge = PsiOpts.settings["opt_eps_converge"]
    if eps_tol is None:
        eps_tol = PsiOpts.settings["opt_eps_tol"]
    rnd = PsiOpts.get_random()
    truth = PsiOpts.settings["truth"]
    cs = self.copy_shallow()
    if isinstance(expr, (int, float)):
        expr = Expr.const(expr)

    # Work on private copies of vs / reg: both are extended or overwritten
    # below, and those changes must not leak into the caller's lists.
    if isinstance(vs, list):
        vs = list(vs)
    else:
        vs = [vs]
    if reg is None:
        reg = []
    elif isinstance(reg, Region) or isinstance(reg, tuple):
        reg = [reg]
    elif isinstance(reg, list):
        reg = list(reg)
    else:
        reg = [(reg, ">=")]
    for i in range(len(reg)):
        if isinstance(reg[i], RegionOp):
            reg[i] = reg[i].tosimple()
            if reg[i] is None:
                reg[i] = Region.universe()
    if truth is not None:
        truth_simple = truth.tosimple()
        if truth_simple is not None:
            reg.append(truth_simple)
    for i in range(len(reg)):
        # (expr, ">=") tuples are consumed directly below and have no
        # allcomprv(); calling it on them used to raise AttributeError.
        if isinstance(reg[i], tuple):
            continue
        if not cs.allcomprv().super_of(reg[i].allcomprv()):
            tvarlist, tcons = cs.add_reg(reg[i])
            reg[i] = tcons
            vs += tvarlist

    def str_tuple(x):
        # Format a (num_violate, score, num_violate) record in the caller's sign.
        if x[0] == 0:
            return iutil.tostr_verbose(x[1] * -sgn)
        return iutil.tostr_verbose((x[0], x[1] * -sgn))

    def tuple_copy1(x, y):
        # Copy the dual variables (index 1) of constraint records y into x.
        for i in range(len(y)):
            if len(x) <= i:
                x.append([None, y[i][1]])
            else:
                x[i][1] = y[i][1]

    varlist, distlist = cs.convert_torch_tensors(vs)

    # Constraint records: [expr, dual] for "== 0", [expr, dual, slack] for ">= 0".
    # With a scipy optimizer the slack slot is just the marker True.
    cons = []
    cons_init = []
    scipy_optimizer = None
    if optimizer == "sgd":
        pass
    elif optimizer == "adam":
        pass
    elif isinstance(optimizer, str):
        scipy_optimizer = optimizer
    for creg in reg:
        if isinstance(creg, Region):
            creg = creg.copy()
            creg.simplify_redundant()
            for a in creg.exprs_ge:
                if a.isnonpos_ic2():
                    if self.bnet.check_ic(a):
                        continue
                if scipy_optimizer is None:
                    slack = torch.tensor(0.0, dtype=torch.float64, requires_grad = True)
                    cons.append([a, 0.0, slack])
                    varlist.append(slack)
                else:
                    cons.append([a, 0.0, True])
            for a in creg.exprs_eq:
                if a.isnonpos_ic2() or a.isnonneg_ic2():
                    if self.bnet.check_ic(a):
                        continue
                cons.append([a, 0.0])
        elif isinstance(creg, tuple):
            if creg[1] == ">=":
                if scipy_optimizer is None:
                    slack = torch.tensor(0.0, dtype=torch.float64, requires_grad = True)
                    cons.append([creg[0], 0.0, slack])
                    varlist.append(slack)
                else:
                    cons.append([creg[0], 0.0, True])

    ismat = False
    bds = None
    cons_list = None
    cons_list_fcn = None
    use_jac = False

    def get_fcn(cexpr, mul, has_f, has_d):
        # Build a scipy-style callable: load x into the tensors, evaluate
        # mul * cexpr, and return the value and/or its gradient.
        def fcn(x):
            cs.clear_cache()
            ConcModel.tensor_list_from_array(varlist, x)
            for d in distlist:
                d.calc_torch()
            cval = cs.get_tensor(cexpr)
            if mul != 1:
                cval *= mul
            if not has_d:
                return float(cval)
            if isinstance(cval, float):
                # Constant objective: gradient is zero.
                if has_f:
                    return (float(cval), ConcModel.tensor_list_grad_to_array(varlist) * 0.0)
                else:
                    return ConcModel.tensor_list_grad_to_array(varlist) * 0.0
            cval.backward()
            if has_f:
                return (float(cval), ConcModel.tensor_list_grad_to_array(varlist))
            else:
                return ConcModel.tensor_list_grad_to_array(varlist)
        return fcn

    if scipy_optimizer is not None:
        ismat = (scipy_optimizer == "trust-constr")
        use_jac = True
        # The dict form (cismat False) is always built: it is used below to
        # measure violations. trust-constr additionally needs the object form.
        for cismat in ([False, True] if ismat else [False]):
            bds, cons_list = ConcModel.tensor_list_get_bds(varlist, distlist, ismat = cismat)
            for a in cons:
                cons_dict = {}
                if len(a) >= 3:
                    cons_dict["type"] = "ineq"
                else:
                    cons_dict["type"] = "eq"
                cons_dict["fun"] = get_fcn(a[0], 1, True, False)
                cons_dict["jac"] = get_fcn(a[0], 1, False, True)
                if cismat:
                    # lb == ub makes NonlinearConstraint an equality constraint;
                    # previously "eq" records were relaxed to ">= 0".
                    cons_list.append(scipy.optimize.NonlinearConstraint(
                        cons_dict["fun"], 0.0, numpy.inf if len(a) >= 3 else 0.0, jac = cons_dict["jac"]
                    ))
                else:
                    cons_list.append(cons_dict)
            if not cismat:
                cons_list_fcn = cons_list

    tuple_copy1(cons_init, cons)
    int_big = 100000000
    # Records are compared lexicographically: (num_violate, score, num_violate).
    r = (int_big, numpy.inf, int_big)
    r_v = []
    r_v_d = []
    t = (int_big, numpy.inf, int_big)
    num_iter_list = [num_iter]
    if num_iter2 > 0:
        num_iter_list.append(num_iter2)

    if verbose:
        print("======== model ========")
        for a in cs.bnet.index.comprv:
            print(str(cs.bnet.get_parents(a)) + " -> " + str(a) + " card=" + str(cs.get_card(a)))
        print("======== probs ========")
        for (cin, cout), cp in cs.psmap.items():
            print(str(sum(cs.bnet.index.comprv[i] for i in cin)) + " -> "
                + str(sum(cs.bnet.index.comprv[i] for i in cout))
                + (" var" if cp.isvar else "") + (" fcn" if cp.isfcn else ""))
        if sgn > 0:
            print("======== maximize ========")
        else:
            print("======== minimize ========")
        print(expr)
        print("======== over ========")
        for d in distlist:
            if isinstance(d, ConcDist):
                print("dist shape=" + str((d.shape_in, d.shape_out)))
            elif isinstance(d, ConcReal):
                print("real=" + str(d.x))
        if len(cons):
            print("======== constraints ========")
            for a in cons:
                if len(a) >= 3:
                    print(str(a[0]) + ">=0")
                else:
                    print(str(a[0]) + "==0")

    for cpass, cur_num_iter in enumerate(num_iter_list):
        cur_num_points = 1
        if cpass == 0:
            cur_num_points = num_points
        tocutoff = False
        for ip in range(cur_num_points):
            if PsiOpts.is_timer_ended():
                break
            if tocutoff:
                break
            if ip > 0:
                # New random starting point with reset dual variables.
                for d in distlist:
                    d.randomize()
                tuple_copy1(cons, cons_init)
            cs.clear_cache()
            cur_lr = learnrate
            cur_num_hop = num_hop
            if cpass > 0:
                cur_lr = learnrate2
                cur_num_hop = 1
            for ih in range(cur_num_hop):
                if PsiOpts.is_timer_ended():
                    break
                if verbose_step:
                    print("Pass #" + str(cpass) + "/" + str(len(num_iter_list)) +
                        ", Point #" + str(ip) + "/" + str(cur_num_points) + ", Hop #" + str(ih) + "/" + str(cur_num_hop))
                r_start = []
                r_start_d = []
                v_start = t
                if ih > 0:
                    # Basin-hopping: remember the current point, then perturb it.
                    ConcModel.tensor_copy_list(r_start, varlist)
                    tuple_copy1(r_start_d, cons)
                    for d in distlist:
                        d.hop(hop_prob)
                    tuple_copy1(cons, cons_init)
                cs.clear_cache()
                if scipy_optimizer is None:
                    if optimizer == "sgd":
                        cur_optimizer = torch.optim.SGD(varlist, lr = cur_lr, momentum = momentum)
                    elif optimizer == "adam":
                        cur_optimizer = torch.optim.Adam(varlist, lr = cur_lr)
                    else:
                        cur_optimizer = optimizer(varlist)
                cr = (int_big, numpy.inf, int_big)
                lastval = numpy.inf
                cur_num_violate = 0

                if scipy_optimizer is not None:
                    res = None
                    resx = None
                    resfun = None
                    if len(varlist):
                        with warnings.catch_warnings():
                            warnings.simplefilter("ignore")
                            opts = {"maxiter": cur_num_iter, "disp": verbose_step}
                            if scipy_optimizer == "trust-constr":
                                opts["initial_tr_radius"] = 0.001
                                if verbose_step:
                                    opts["verbose"] = 3
                            res = scipy.optimize.minimize(
                                get_fcn(expr, -sgn, True, use_jac),
                                ConcModel.tensor_list_to_array(varlist),
                                method = scipy_optimizer,
                                jac = use_jac,
                                bounds = bds,
                                constraints = cons_list,
                                tol = eps_tol,
                                options = opts
                            )
                        resx = res.x
                        resfun = res.fun
                    else:
                        resx = numpy.array([])
                        resfun = get_fcn(expr, -sgn, True, use_jac)(numpy.array([]))[0]
                    ConcModel.tensor_list_from_array(varlist, resx)
                    bad = False
                    for d in distlist:
                        if isinstance(d, ConcDist) and d.isfcn:
                            d.clamp()
                            bad = True
                    if bad:
                        # Clamping moved the point; re-read it and re-evaluate.
                        resx = ConcModel.tensor_list_to_array(varlist)
                        resfun = get_fcn(expr, -sgn, True, use_jac)(resx)[0]
                    num_violate = 0
                    sum_violate = 0.0
                    for a in cons_list_fcn:
                        t = a["fun"](resx)
                        if numpy.isnan(t):
                            sum_violate = numpy.inf
                            num_violate += 1
                            continue
                        if a["type"] == "eq":
                            if abs(t) > eps_tol:
                                sum_violate += abs(t)
                                num_violate += 1
                        else:
                            if t < -eps_tol:
                                sum_violate += -t
                                num_violate += 1
                    penalty = sum_violate * alm_penalty
                    score = resfun + penalty
                    if numpy.isnan(score):
                        score = numpy.inf
                    t = (0, score, num_violate)
                    cur_num_violate = num_violate
                    if verbose_step:
                        print((("pass=" + str(cpass + 1) + " ") if cpass > 0 else "")
                            + "#pt=" + str(ip)
                            + " val=" + str(resfun)
                            + " violate=" + str((num_violate, sum_violate)) + " opt=" + str_tuple(r))
                else:
                    for it in range(cur_num_iter + 1):
                        if PsiOpts.is_timer_ended():
                            break
                        cval = cs.get_tensor(expr)
                        if -sgn != 1:
                            cval *= -sgn
                        t_cval = float(cval)
                        if numpy.isnan(t_cval):
                            t_cval = numpy.inf
                        # rho grows by alm_rho_pow every alm_step iterations.
                        crho = numpy.power(alm_rho_pow, it // alm_step) * alm_rho
                        crho0 = numpy.power(alm_rho_pow, (it - 1) // alm_step) * alm_rho
                        num_violate = 0
                        sum_dual_step = 0.0
                        penalty = 0.0
                        for a in cons:
                            con_val = cs.get_tensor(a[0])
                            p_con_val = float(con_val)
                            if len(a) >= 3:
                                if p_con_val < -eps_tol:
                                    num_violate += 1
                                    penalty += -p_con_val * alm_penalty
                            else:
                                if abs(p_con_val) > eps_tol:
                                    num_violate += 1
                                    penalty += abs(p_con_val) * alm_penalty
                            if len(a) >= 3:
                                # g >= 0 becomes g - slack == 0 with slack >= 0.
                                con_val -= a[2]
                            t_con_val = float(con_val)
                            if numpy.isnan(t_con_val):
                                t_con_val = 0.0
                            sum_dual_step += abs(t_con_val)
                            if it > 0 and it % alm_step == 0:
                                # Dual ascent step of the augmented Lagrangian method.
                                t = a[1] + crho0 * t_con_val
                                if not numpy.isnan(t):
                                    a[1] = t
                            cval += con_val * a[1] + torch.square(con_val) * crho * 0.5
                            if verbose_step:
                                if len(a) >= 3:
                                    print(str(a[0]) + ">=0 : val=" + iutil.tostr_verbose(p_con_val)
                                        + " dual=" + iutil.tostr_verbose(a[1]) + " slack=" + iutil.tostr_verbose(float(a[2])))
                                else:
                                    print(str(a[0]) + "==0 : val=" + iutil.tostr_verbose(p_con_val)
                                        + " dual=" + iutil.tostr_verbose(a[1]))
                        t = (num_violate, t_cval + penalty, num_violate)
                        cur_num_violate = num_violate
                        cr = min(cr, t)
                        if not numpy.isnan(cr[1]) and cr < r:
                            r = cr
                            ConcModel.tensor_copy_list(r_v, varlist)
                            tuple_copy1(r_v_d, cons)
                        if verbose_step:
                            print((("pass=" + str(cpass + 1) + " ") if cpass > 0 else "")
                                + "#pt=" + str(ip) + " #iter=" + str(it)
                                + " val=" + str_tuple(t) + " opt=" + str_tuple(min(r, cr)))
                            if verbose_step_var:
                                for d in distlist:
                                    print(d)
                        if abs(t_cval - lastval) + sum_dual_step <= eps_converge:
                            break
                        if it == cur_num_iter:
                            break
                        lastval = t_cval
                        cur_optimizer.zero_grad()
                        cval.backward()
                        cur_optimizer.step()
                        for d in distlist:
                            d.clamp()
                        for a in cons:
                            if len(a) >= 3:
                                # Project slack variables back onto slack >= 0.
                                if numpy.isnan(float(a[2])) or float(a[2]) < 0:
                                    with torch.no_grad():
                                        a[2].copy_(torch.tensor(0.0, dtype=torch.float64))
                        cs.clear_cache()

                # Record the result of this hop (both scipy and torch paths).
                cr = min(cr, t)
                if not numpy.isnan(cr[1]) and cr < r:
                    r = cr
                    ConcModel.tensor_copy_list(r_v, varlist)
                    tuple_copy1(r_v_d, cons)

                toreject = False
                if ih > 0:
                    # Metropolis acceptance test for the basin-hopping move.
                    accept_prob = 1.0
                    if numpy.isnan(t[1]):
                        accept_prob = 0.0
                    elif t > v_start:
                        accept_prob = numpy.exp((v_start[1] - t[1]) / hop_temp)
                    if verbose_step:
                        print("HOP #" + str(ih) + " from=" + str_tuple(v_start)
                            + " to=" + str_tuple(t) + " prob=" + iutil.tostr_verbose(accept_prob))
                    if rnd.uniform() >= accept_prob:
                        toreject = True
                        if verbose_step:
                            print("HOP REJECT")
                        ConcModel.tensor_copy_list(varlist, r_start)
                        tuple_copy1(cons, r_start_d)
                        for d in distlist:
                            d.calc_torch()
                        t = v_start
                        cs.clear_cache()
                    if not toreject:
                        if verbose_step:
                            print("HOP ACCEPT")
                if optval_cutoff is not None and t[2] == 0 and t[1] <= optval_cutoff * -sgn:
                    if verbose_step:
                        print("CUTOFF")
                    tocutoff = True
                if tocutoff:
                    break

    # Restore the best point found.
    if len(r_v):
        ConcModel.tensor_copy_list(varlist, r_v)
        tuple_copy1(cons, r_v_d)
        for d in distlist:
            d.calc_torch()
        cs.clear_cache()
    self.opt_reg = cs
    self.clear_cache()
    return (r[1] if r[2] == 0 else numpy.inf) * -sgn
def minimize(self, *args, **kwargs):
    """
    Minimize expr with variables in the list vs, constrained in the region reg.
    Equivalent to optimize(..., sgn = -1). Refer to optimize for details.
    """
    return self.optimize(*args, sgn = -1, **kwargs)
def maximize(self, *args, **kwargs):
    """
    Maximize expr over the variables vs, subject to the region reg.
    Shorthand for optimize(..., sgn = 1); see optimize for the parameters.
    """
    run = self.optimize
    return run(*args, sgn = 1, **kwargs)
def opt_model(self):
    """Return the model stored in ``self.opt_reg`` by the last optimization."""
    model = self.opt_reg
    return model
def get_bayesnet(self):
    """Return a copy (via ``.copy()``) of this model's Bayesian network."""
    bnet = self.bnet
    return bnet.copy()
def table(self, *args, **kwargs):
    """Plot the information diagram as a Karnaugh map.

    Delegates to ``universe().table`` with this model appended to the
    positional arguments.
    """
    whole = universe()
    return whole.table(*args, self, **kwargs)
def venn(self, *args, **kwargs):
    """Plot the information diagram as a Venn diagram.
    Can handle up to 5 random variables (uses Branko Grunbaum's Venn diagram for n=5).

    Delegates to ``universe().venn`` with this model appended to the
    positional arguments.
    """
    whole = universe()
    return whole.venn(*args, self, **kwargs)
def graph(self, **kwargs):
    """Return the Bayesian network among the random variables as a
    graphviz digraph that can be displayed in the console.

    Keyword arguments are forwarded to the network's ``graph`` method.
    """
    bnet = self.get_bayesnet()
    return bnet.graph(**kwargs)
def set_force_float(self, force_float):
    """Set the ``force_float`` flag on every probability table and real variable.

    Parameters:
        force_float : value assigned to ``force_float`` of each entry of
            ``self.psmap`` (values only) and of ``self.realvars``.
    """
    for t in self.psmap.values():
        t.force_float = force_float
    for x in self.realvars:
        x.force_float = force_float
def fraction_snap(self, denom = None, eps = None):
    """Call ``fraction_snap(denom, eps)`` on every probability table and real variable.

    Parameters:
        denom, eps : forwarded unchanged to each entry's ``fraction_snap``.
    """
    for t in self.psmap.values():
        t.fraction_snap(denom = denom, eps = eps)
    for x in self.realvars:
        x.fraction_snap(denom = denom, eps = eps)
def tostring(self, style = 0):
    """Convert the model to a string.

    Produces one line per probability table, ``P(out|in) = <table>``, where
    ``|in`` is omitted when there are no conditioning variables. Then one
    line per real variable, ``name = value``. With a LaTeX style the lines
    are separated by ``\\\\`` and wrapped in ``\\begin{array}{l}...\\end{array}``.

    Parameters:
        style : string style, anything accepted by iutil.convert_str_style
            (e.g. "latex").
    """
    style = iutil.convert_str_style(style)
    islatex = bool(style & PsiOpts.STR_STYLE_LATEX)
    nlstr = "\\\\\n" if islatex else "\n"
    parts = []
    if islatex:
        parts.append("\\begin{array}{l}\n")
    for (cin, cout), t in self.psmap.items():
        cinrv = self.tuple_to_comp(cin)
        coutrv = self.tuple_to_comp(cout)
        parts.append("P(")
        parts.append(coutrv.tostring(style))
        if cinrv:
            parts.append("|")
            parts.append(cinrv.tostring(style))
        parts.append(")")
        parts.append(" = ")
        parts.append(t.tostring(style))
        parts.append(nlstr)
    for i, x in enumerate(self.index_real.compreal):
        parts.append(x.tostring(style))
        parts.append(" = ")
        v = self.realvars[i]
        if isinstance(v, (int, fractions.Fraction)):
            parts.append(str(v))
        elif isinstance(v, float):
            parts.append(iutil.float_tostr(v))
        else:
            parts.append(v.tostring(style))
        parts.append(nlstr)
    if islatex:
        parts.append("\\end{array}")
    return "".join(parts)
def __str__(self):
    """String form using the global style PsiOpts.settings["str_style"]."""
    style = PsiOpts.settings["str_style"]
    return self.tostring(style)
@latex_postprocess
def _latex_(self):
    """Return this model rendered with the "latex" string style."""
    latex_style = iutil.convert_str_style("latex")
    return self.tostring(latex_style)
class SparseMat:
    """List of lists sparse matrix. Do NOT use directly.

    ``x`` is a list of rows. Each row is a list of ``(column, coefficient)``
    pairs. ``rowinfo`` is a parallel list holding one annotation per row.
    ``width`` is the number of columns.
    """

    def __init__(self, width):
        self.width = width
        self.x = []
        self.rowinfo = []

    def copy(self):
        """Return a copy with fresh row lists (rowinfo copied via iutil.copy)."""
        r = SparseMat(self.width)
        r.x = [list(a) for a in self.x]
        r.rowinfo = iutil.copy(self.rowinfo)
        return r

    @staticmethod
    def from_row(row, width):
        """Create a one-row matrix from an iterable of (column, coefficient) pairs."""
        r = SparseMat(width)
        r.x.append(list(row))
        r.rowinfo.append(None)
        return r

    @staticmethod
    def from_dense_row(row):
        """Create a one-row matrix from a dense row, keeping entries with |value| > eps."""
        ceps = PsiOpts.settings["eps"]
        r = SparseMat(len(row))
        r.x.append([(i, v) for i, v in enumerate(row) if abs(v) > ceps])
        r.rowinfo.append(None)
        return r

    def iszero(self):
        """Return True if no row stores any entry."""
        return all(len(a) == 0 for a in self.x)

    def __iadd__(self, other):
        """Add ``other`` entrywise in place.

        For each entry of ``other``, the first entry of the same column in the
        corresponding row is updated (and removed if |sum| <= eps). Otherwise
        the entry is appended.

        Raises:
            RuntimeError: if the widths or row counts differ.
        """
        if self.width != other.width or len(self.x) != len(other.x):
            raise RuntimeError("SparseMat iadd shape mismatch.")
        ceps = PsiOpts.settings["eps"]
        for a, b in zip(self.x, other.x):
            for j in range(len(b)):
                for i in range(len(a)):
                    if a[i][0] == b[j][0]:
                        csum = a[i][1] + b[j][1]
                        if abs(csum) > ceps:
                            a[i] = (a[i][0], csum)
                        else:
                            a.pop(i)
                        break
                else:
                    # Column not present in this row yet.
                    a.append(b[j])
        return self

    def __add__(self, other):
        r = self.copy()
        r += other
        return r

    def __imul__(self, other):
        """Scale every stored coefficient by ``other`` in place."""
        for a in self.x:
            a[:] = [(j, c * other) for (j, c) in a]
        return self

    def __mul__(self, other):
        r = self.copy()
        r *= other
        return r

    def __isub__(self, other):
        self += other * -1
        return self

    def __sub__(self, other):
        r = self.copy()
        r += other * -1
        return r

    def ratio(self, other):
        """Return c such that self == c * other (entry by entry, up to eps), else None.

        Rows are compared pairwise (via zip) and must have identical column
        patterns. Entries that are ~0 in both are ignored. Returns None if
        no ratio could be determined.
        """
        ceps = PsiOpts.settings["eps"]
        r = None
        for a, b in zip(self.x, other.x):
            if len(a) != len(b):
                return None
            for ax, bx in zip(a, b):
                if ax[0] != bx[0]:
                    return None
                if abs(bx[1]) <= ceps:
                    if abs(ax[1]) <= ceps:
                        continue
                    else:
                        return None
                t = ax[1] / bx[1]
                if r is not None and abs(t - r) > ceps:
                    return None
                r = t
        return r

    def addrow(self, info = None):
        """Append an empty row with annotation ``info``."""
        self.x.append([])
        self.rowinfo.append(info)

    def poprow(self):
        """Remove the last row and its annotation."""
        self.x.pop()
        self.rowinfo.pop()

    def add_last_row(self, j, c):
        """Append the entry (j, c) to the last row (duplicates are allowed)."""
        self.x[-1].append((j, c))

    def extend(self, other):
        """Append the rows of ``other``; the width becomes the larger of the two."""
        self.width = max(self.width, other.width)
        self.x += other.x
        self.rowinfo += other.rowinfo

    def simplify_row(self, i):
        """Sort row i by column, merge duplicate columns, and drop entries with |value| <= eps."""
        ceps = PsiOpts.settings["eps"]
        self.x[i].sort()
        t = self.x[i]
        self.x[i] = []
        cj = -1
        cc = 0.0
        for (j, c) in t:
            if j == cj:
                cc += c
            else:
                if abs(cc) > ceps:
                    self.x[i].append((cj, cc))
                cj = j
                cc = c
        if abs(cc) > ceps:
            self.x[i].append((cj, cc))

    def simplify_last_row(self):
        self.simplify_row(len(self.x) - 1)

    def simplify(self):
        for i in range(len(self.x)):
            self.simplify_row(i)

    def last_row_isempty(self):
        return len(self.x[-1]) == 0

    def unique_row(self):
        """Remove duplicate rows (keeping the first occurrence), then all empty rows.

        Rows are compared exactly (as tuples of entries). The previous
        hash-bucket approach could miss duplicates after a hash collision.
        """
        seen = set()
        for a in self.x:
            key = tuple(a)
            if key in seen:
                a[:] = []
            else:
                seen.add(key)
        self.remove_empty_rows()

    def remove_empty_rows(self):
        """Drop rows with no entries, together with their annotations."""
        self.rowinfo = [ri for (a, ri) in zip(self.x, self.rowinfo) if len(a) > 0]
        self.x = [a for a in self.x if len(a) > 0]

    def row_dense(self, i):
        """Return row i as a dense Python list of length ``width``."""
        r = [0.0] * self.width
        for (j, c) in self.x[i]:
            r[j] += c
        return r

    def nonzero_cols(self):
        """Return a list of bools marking the columns that have a stored entry."""
        r = [False] * self.width
        for a in self.x:
            for (j, c) in a:
                r[j] = True
        return r

    def mapcol(self, m, allowmiss = False):
        """Renumber columns in place through the lookup ``m`` (new = m[old]).

        With allowmiss, entries mapped to a negative column are dropped.
        Without it, the method returns False on the first negative mapping.
        Rows already processed are left renumbered in that case.
        """
        for i in range(len(self.x)):
            for k in range(len(self.x[i])):
                j2 = m[self.x[i][k][0]]
                if not allowmiss and j2 < 0:
                    return False
                self.x[i][k] = (j2, self.x[i][k][1])
            if allowmiss:
                self.x[i] = [(j, c) for (j, c) in self.x[i] if j >= 0]
        return True

    def sumrows(self, m, remove = False):
        """Return, per row, sum of m[col] * coefficient over columns present in dict ``m``.

        With remove, those entries are deleted from the rows.
        """
        r = []
        for i2 in range(len(self.x)):
            a = self.x[i2]
            cr = 0.0
            did = False
            for i in range(len(a)):
                if a[i][0] in m:
                    cr += m[a[i][0]] * a[i][1]
                    if remove:
                        a[i] = (-1, a[i][1])
                        did = True
            if did:
                self.x[i2] = [(j, c) for (j, c) in self.x[i2] if j >= 0]
            r.append(cr)
        return r

    def tolil(self):
        """Convert to a scipy.sparse.lil_matrix (duplicate columns are summed)."""
        r = scipy.sparse.lil_matrix((len(self.x), self.width))
        for i in range(len(self.x)):
            for (j, c) in self.x[i]:
                r[i, j] += c
        return r

    def tonumpyarray(self):
        """Convert to a dense numpy array (duplicate columns are summed)."""
        r = numpy.zeros((len(self.x), self.width))
        for i in range(len(self.x)):
            for (j, c) in self.x[i]:
                r[i, j] += c
        return r

    def tonumpyarray_row(self):
        """Return the first row as a dense 1-D numpy array."""
        r = numpy.zeros(self.width)
        for (j, c) in self.x[0]:
            r[j] += c
        return r
class LinearProg:
"""A linear programming instance. Do NOT use directly"""
def __init__(self, index, lptype, bnet = None, lp_bounded = None, save_res = False, prereg = None, dual_enabled = None, val_enabled = None, dual_form = None):
    """Set up an empty linear program over the random variables of ``index``.

    Reads its options from PsiOpts.settings (explicit arguments override the
    corresponding setting). It then chooses the column layout for joint
    entropies according to ``self.lptype``:
    - H: one column per nonempty subset of random variables (column mask - 1).
    - HMIN: subsets are grouped via celllink/cellsplit so that only
      "irreducible" subsets get a column (see addH_mask).
    - HC1BN: one column per conditional entropy H(X_i | subset of X_<i),
      with cells aliased or removed using functional dependencies and
      Bayesian-network conditional independences.
    Real variables get columns starting at ``self.realshift``.

    Parameters:
        index : variable index; its num_rv()/num_real() are used here.
        lptype : requested LinearProgType (forced to H when any of the
            quantum/hge0/hcge0/ige0/icge0 settings require it).
        bnet : Bayesian network. Accessed unconditionally by the HMIN and
            HC1BN layouts (NOTE(review): presumably must not be None then).
        prereg : optional region whose exprs_gei/exprs_eqi are registered as
            functional dependencies when fcn_mode >= 1.
        lp_bounded, save_res, dual_enabled, val_enabled, dual_form :
            override the corresponding settings.
    """
    self.index = index
    self.lptype = lptype
    self.nvar = 0
    self.nxvar = 0
    self.xvarid = []
    self.realshift = 0
    self.cellpos = []
    self.bnet = bnet
    self.pinfeas = False
    self.constmap = {}
    save_enabled = PsiOpts.settings.get("lp_save_enabled", False)
    if not save_enabled:
        save_res = False
    self.quantum = PsiOpts.settings.get("quantum", False)
    self.hge0 = PsiOpts.settings.get("hge0", False)
    self.hcge0 = PsiOpts.settings.get("hcge0", False)
    self.ige0 = PsiOpts.settings.get("ige0", False)
    self.icge0 = PsiOpts.settings.get("icge0", False)
    # Unless all of these sign assumptions hold (and the setting is not
    # quantum), fall back to the plain one-column-per-subset layout.
    if self.quantum or not self.hge0 or not self.hcge0 or not self.ige0 or not self.icge0:
        self.lptype = LinearProgType.H
    self.noskip = (PsiOpts.settings["proof_noskip"] and PsiOpts.settings["proof_enabled"])
    if dual_form is None:
        self.dual_form = PsiOpts.settings["lp_dual_form"] or (PsiOpts.settings["lp_dual_form_if_proof"] and PsiOpts.settings["proof_enabled"])
    else:
        self.dual_form = dual_form
    self.dual_form_ncons = 0
    self.dual_form_obj = None
    self.dual_form_cutoff = 0.0
    self.dual_form_infeas = False
    self.dual_form_weights = None
    self.dual_discover_exprs = []
    self.dual_discover_init_model = None
    self.dual_discover_l1 = False
    self.vdiscover = []
    self.vdiscover_bounds = []
    self.vdiscover_disabled = set()
    if PsiOpts.settings["lp_dual_form_discover"]:
        self.dual_form_obj = "discover"
    self.cond_maximize = PsiOpts.settings["lp_cond_maximize"]
    if lp_bounded is None:
        self.lp_bounded = PsiOpts.settings["lp_bounded"]
    else:
        self.lp_bounded = lp_bounded
    self.lp_ubound = PsiOpts.settings["lp_ubound"]
    self.lp_eps = PsiOpts.settings["lp_eps"]
    self.lp_eps_obj = PsiOpts.settings["lp_eps_obj"]
    self.zero_cutoff = PsiOpts.settings["lp_zero_cutoff"]
    self.eps_present = False
    self.affine_present = False
    self.fcn_mode = PsiOpts.settings["fcn_mode"]
    self.fcn_list = []
    self.save_res = save_res
    self.saved_var = []
    self.celllink = []
    self.cellsplit = []
    # Functional dependencies from prereg must be known before the cell
    # layout below is computed (it queries them via fcn_mask_maximize /
    # checkfcn_mask).
    if prereg is not None:
        if self.fcn_mode >= 1:
            for x in prereg.exprs_gei:
                self.addExpr_ge0_fcn(x)
            for x in prereg.exprs_eqi:
                self.addExpr_ge0_fcn(x)
    if self.lptype == LinearProgType.H:
        # Column (mask - 1) holds H(X_mask) for each nonempty subset mask.
        self.nvar = (1 << self.index.num_rv()) - 1 + self.index.num_real()
        self.realshift = (1 << self.index.num_rv()) - 1
    elif self.lptype == LinearProgType.HMIN:
        n = self.index.num_rv()
        nbnet = bnet.index.num_rv()
        self.celllink = [-2] * (1 << n)
        self.cellsplit = [None] * (1 << n)
        self.cellpos = [-2] * (1 << n)
        cpos = 0
        # celllink[mask]: representative subset for mask. Masks with the same
        # fcn_mask_maximize(...) result share the first such mask as
        # representative.
        for mask in range(1, 1 << n):
            mask2 = self.fcn_mask_maximize(mask)
            if self.celllink[mask2] < 0:
                self.celllink[mask2] = mask
                self.celllink[mask] = mask
            else:
                self.celllink[mask] = self.celllink[mask2]
        for mask in range(1, 1 << n):
            # For representatives within the Bayesian network, try to split
            # mask = {i} + rest using bnet.markov_blanket_mask; addH_mask then
            # expands H(mask) = H(a|c) + H(b|c) - H(c) over the split
            # (a, b, c). NOTE(review): relies on markov_blanket_mask giving a
            # set c with X_i independent of the rest given c -- confirm.
            if mask < (1 << nbnet) and mask == self.celllink[mask]:
                for i in range(n):
                    if not mask & (1 << i):
                        continue
                    maskb = bnet.markov_blanket_mask(1 << i, mask - (1 << i))
                    if maskb == mask - (1 << i):
                        continue
                    self.cellsplit[mask] = (1 << i, mask - (1 << i) - maskb, maskb)
                    break
            # Only unsplit representatives get their own LP column.
            if mask == self.celllink[mask] and self.cellsplit[mask] is None:
                self.cellpos[mask] = cpos
                cpos += 1
        self.realshift = cpos
        self.nvar = self.realshift + self.index.num_real()
    elif self.lptype == LinearProgType.HC1BN:
        # cellpos[mask] for a mask whose highest bit is i is the column of
        # H(X_i | X_{mask - i}); -1 means the cell is identically zero,
        # -2 means not assigned yet. addH_mask sums these cells (chain rule).
        n = self.index.num_rv()
        nbnet = bnet.index.num_rv()
        self.cellpos = [-2] * (1 << n)
        cpos = 0
        if self.cond_maximize:
            # Conditioning sets are visited from largest to smallest, so a
            # cell may alias one with a larger conditioning set.
            for i in range(n):
                for mask in range((1 << i) - 1, -1, -1):
                    maski = mask + (1 << i)
                    if self.fcn_mode >= 1:
                        mask2 = self.fcn_mask_maximize(mask)
                        if mask2 & (1 << i):
                            # X_i is determined by the conditioning set.
                            self.cellpos[maski] = -1
                            continue
                        mask2 &= (1 << i) - 1
                        if mask2 != mask:
                            if mask2 < mask:
                                print("HC1BN REVERSED!")
                            self.cellpos[maski] = self.cellpos[mask2 + (1 << i)]
                            continue
                    if i < nbnet:
                        for j in range(i):
                            if mask & (1 << j) == 0:
                                # presumably: X_i independent of X_j given mask, so adding
                                # j to the conditioning set does not change the cell
                                if bnet.check_ic_mask(1 << i, 1 << j, mask):
                                    self.cellpos[maski] = self.cellpos[maski + (1 << j)]
                                    break
                    if self.cellpos[maski] == -2:
                        self.cellpos[maski] = cpos
                        cpos += 1
        else:
            # Conditioning sets are visited from smallest to largest, so a
            # cell may alias one with a smaller conditioning set.
            for i in range(n):
                for mask in range(1 << i):
                    maski = mask + (1 << i)
                    if self.fcn_mode >= 1:
                        # mask_max = self.fcn_mask_maximize(mask)
                        # if mask_max & (1 << i):
                        # self.cellpos[maski] = -1
                        # continue
                        # mask_min = self.fcn_mask_minimize(mask_max)
                        # if mask_min < mask:
                        # maskj = mask_min + (1 << i)
                        # self.cellpos[maski] = self.cellpos[maskj]
                        # continue
                        if self.checkfcn_mask(1 << i, mask):
                            # X_i is a function of the conditioning set: zero cell.
                            self.cellpos[maski] = -1
                            continue
                        for j in range(i):
                            if mask & (1 << j) != 0:
                                # X_j is a function of the rest of the conditioning
                                # set, so it can be dropped from it.
                                if self.checkfcn_mask(1 << j, mask - (1 << j)):
                                    maskj = mask - (1 << j) + (1 << i)
                                    self.cellpos[maski] = self.cellpos[maskj]
                                    break
                        if self.cellpos[maski] != -2:
                            continue
                    if i >= nbnet:
                        # Variable outside the Bayesian network: no independence info.
                        self.cellpos[maski] = cpos
                        cpos += 1
                        continue
                    for j in range(i):
                        if mask & (1 << j) != 0:
                            # presumably: X_i independent of X_j given the rest, so
                            # j can be dropped from the conditioning set
                            if bnet.check_ic_mask(1 << i, 1 << j, mask - (1 << j)):
                                maskj = mask - (1 << j) + (1 << i)
                                self.cellpos[maski] = self.cellpos[maskj]
                                break
                    if self.cellpos[maski] == -2:
                        self.cellpos[maski] = cpos
                        cpos += 1
        self.realshift = cpos
        self.nvar = self.realshift + self.index.num_real()
    self.nxvar = self.nvar
    # Au/bu: inequality rows; Ae/be: equality rows (filled elsewhere).
    self.Au = SparseMat(self.nvar)
    self.Ae = SparseMat(self.nvar)
    self.solver = None
    # Small programs use the scipy solver; larger ones use the default solver.
    if self.nvar <= PsiOpts.settings["solver_scipy_maxsize"]:
        self.solver = iutil.get_solver("scipy")
    else:
        self.solver = iutil.get_solver()
    self.solver_param = {}
    self.bu = []
    self.be = []
    self.icp = [[] for i in range(self.index.num_rv())]
    self.dual_u = None
    self.dual_e = None
    self.dual_pf = (PsiOpts.settings["proof_enabled"]
        and not PsiOpts.settings["proof_nowrite"])
    if dual_enabled is not None:
        self.dual_enabled = dual_enabled
    else:
        self.dual_enabled = PsiOpts.settings["proof_enabled"]
    self.val_x = None
    if val_enabled is not None:
        self.val_enabled = val_enabled
    else:
        self.val_enabled = False
    self.optval = None
def get_optval(self):
    """Return the optimal value stored by the most recent solve.

    The constructor initialises ``self.optval`` to None, so None is returned
    until a solve has stored a value.
    """
    optimal_value = self.optval
    return optimal_value
def addreal_id(self, A, k, c):
    """Add coefficient ``c`` for real variable number ``k`` to the last row of ``A``.

    Real-variable columns come after the entropy columns, so the column index
    is offset by ``self.realshift``.
    """
    column = k + self.realshift
    A.add_last_row(column, c)
def addH_mask(self, A, mask, c):
    """Add ``c`` times the joint entropy of the variables in ``mask`` to the last row of ``A``.

    How the entropy maps onto LP columns depends on ``self.lptype``:

    * ``H``: one column per nonempty subset, at index ``mask - 1``.
    * ``HC1BN``: the chain rule is applied over the set bits of ``mask``.
      For each variable i in ``mask``, the cell of H(X_i | X_{mask bits < i})
      is looked up in ``self.cellpos``. Negative positions have no column and
      are skipped.
    * ``HMIN``: ``mask`` is first redirected through ``self.celllink``.
      * An unsplit cell maps to ``self.cellpos``.
      * A split cell ``(ca, cb, cc)`` is expanded recursively as
        H(ca|cc) + H(cb|cc) - H(cc), where ``|`` is the bitwise union; the
        H(cc) part is omitted when cc is 0.

    Any other lptype adds nothing.
    """
    lptype = self.lptype
    if lptype == LinearProgType.H:
        A.add_last_row(mask - 1, c)
        return
    if lptype == LinearProgType.HC1BN:
        for i in range(self.index.num_rv()):
            bit = 1 << i
            if not mask & bit:
                continue
            # prefix of mask up to and including bit i selects the cell H(X_i | earlier bits)
            cell = self.cellpos[mask & ((bit << 1) - 1)]
            if cell >= 0:
                A.add_last_row(cell, c)
        return
    if lptype == LinearProgType.HMIN:
        linked = self.celllink[mask]
        split = self.cellsplit[linked]
        if split is None:
            A.add_last_row(self.cellpos[linked], c)
            return
        ca, cb, cc = split
        self.addH_mask(A, ca | cc, c)
        self.addH_mask(A, cb | cc, c)
        if cc:
            self.addH_mask(A, cc, -c)
def addIc_mask(self, A, x, y, zmask, c):
    """Add ``c * I(X_x; X_y | X_zmask)`` to the last row of ``A``.

    ``x`` and ``y`` are variable indices and ``zmask`` is a bitmask of
    conditioning variables. When ``x == y`` the term is the conditional
    entropy H(X_x | X_zmask). The expansion into joint entropies is:

        I(X;Y|Z) = H(X,Z) + H(Y,Z) - H(X,Y,Z) - H(Z)
        H(X|Z)   = H(X,Z) - H(Z)

    The H(Z) part is skipped when zmask is 0. Every joint entropy goes through
    ``addH_mask``. For ``LinearProgType.H`` that writes column ``mask - 1``
    directly. For lptypes that ``addH_mask`` does not handle, nothing is
    added.
    """
    xz = (1 << x) | zmask
    self.addH_mask(A, xz, c)
    if x != y:
        yz = (1 << y) | zmask
        self.addH_mask(A, yz, c)
        self.addH_mask(A, xz | (1 << y), -c)
    if zmask != 0:
        self.addH_mask(A, zmask, -c)
def addExpr(self, A, x):
    """Add the LP coefficients of expression ``x`` to the last row of ``A``.

    Each ``(term, coefficient)`` pair of ``x.terms`` is handled by type:

    * ``TermType.IC`` terms I(X_1; ...; X_k | Z) are expanded by
      inclusion-exclusion over the 2^k subsets S of {X_1, ..., X_k}. The joint
      entropy H(X_S, Z) gets sign (-1)^(|S|+1), and an empty joint entropy
      (mask 0) is skipped.
    * ``TermType.REAL`` and ``TermType.REGION`` terms add their coefficient to
      the column of their real variable, when the index knows it (index >= 0).

    Other term types contribute nothing.
    """
    def expand_ic(term, coeff):
        # inclusion-exclusion over the subsets of term.x
        for subset in range(1 << len(term.x)):
            sign = -1
            cell = self.index.get_mask(term.z)
            for pos, arg in enumerate(term.x):
                if subset >> pos & 1:
                    sign = -sign
                    cell |= self.index.get_mask(arg)
            if cell != 0:
                self.addH_mask(A, cell, coeff * sign)

    for term, coeff in x.terms:
        kind = term.get_type()
        if kind == TermType.IC:
            expand_ic(term, coeff)
        elif kind == TermType.REAL or kind == TermType.REGION:
            real_id = self.index.get_index(term.x[0].varlist[0])
            if real_id >= 0:
                self.addreal_id(A, real_id, coeff)
def addfcn(self, x):
    """Record functional dependencies read off the conditional-entropy terms of ``x``.

    Every ``TermType.IC`` term with a single argument is a conditional
    entropy H(X|Z), as shown by its expansion in ``addExpr``. For each such
    term, the pair ``(mask(X), mask(Z))`` is appended to ``self.fcn_list``.
    Other terms are ignored.
    """
    for term, _coeff in x.terms:
        if term.get_type() != TermType.IC or len(term.x) != 1:
            continue
        dependency = (self.index.get_mask(term.x[0]), self.index.get_mask(term.z))
        self.fcn_list.append(dependency)
def fcn_mask_maximize(self, zmask):
    """Return the closure of ``zmask`` under the recorded functional dependencies.

    Each entry ``(dep_x, dep_z)`` of ``self.fcn_list`` means the variables in
    ``dep_x`` are determined by those in ``dep_z``. Whenever ``dep_z`` is
    already inside the current mask, ``dep_x`` is merged in. This repeats
    until a full pass adds nothing.
    """
    closure = zmask
    while True:
        grown = False
        for dep_x, dep_z in self.fcn_list:
            if (dep_z | closure) == closure and (dep_x | closure) != closure:
                closure |= dep_x
                grown = True
        if not grown:
            return closure
def fcn_mask_minimize(self, zmask):
    """Greedily drop variables from ``zmask`` that are determined by the rest.

    Variables are tried from the highest index down. Bit i is removed when
    ``self.checkfcn_mask`` reports that X_i is a function of the remaining
    mask. Returns the reduced mask.
    """
    for i in reversed(range(self.index.num_rv())):
        bit = 1 << i
        if zmask & bit and self.checkfcn_mask(bit, zmask - bit):
            zmask -= bit
    return zmask
def checkfcn_mask(self, xmask, zmask):
    """Return True if every variable in ``xmask`` is determined by ``zmask``.

    "Determined" means the variable lies in the closure of ``zmask`` under
    the dependencies in ``self.fcn_list`` (see ``fcn_mask_maximize``). The
    closure is grown pass by pass and stops early once ``xmask`` is covered.
    A negative ``xmask`` (unknown variables) always gives False.
    """
    if xmask < 0:
        return False
    known = zmask
    changed = True
    while (known | xmask) != known and changed:
        changed = False
        for dep_x, dep_z in self.fcn_list:
            if (dep_z | known) == known and (dep_x | known) != known:
                known |= dep_x
                changed = True
    return (known | xmask) == known
def checkfcn(self, x, z):
    """Return True if ``x`` is a function of ``z`` given the recorded dependencies.

    ``z`` is first restricted to the random variables known to the index
    (``z.inter(self.index.comprv)``). Returns False when ``x`` has no valid
    mask.
    """
    index = self.index
    xmask = index.get_mask(x)
    if xmask >= 0:
        zmask = index.get_mask(z.inter(index.comprv))
        return self.checkfcn_mask(xmask, zmask)
    return False
def get_bnet_fcn_region(self):
    """Return a Region encoding the Bayesian network and the recorded functional dependencies.

    The region starts from ``self.bnet.get_region()`` when a network is
    attached, and from ``Region.universe()`` otherwise. It is then
    intersected with H(X | Z) == 0 for every ``(X, Z)`` mask pair in
    ``self.fcn_list``.
    """
    region = Region.universe()
    if self.bnet is not None:
        region = self.bnet.get_region()
    for dep_x, dep_z in self.fcn_list:
        hc = Expr.Hc(self.index.from_mask(dep_x), self.index.from_mask(dep_z))
        region &= hc == 0
    return region
def addExpr_ge0(self, x):
    """Add the constraint ``x >= 0`` to the inequality system.

    It is stored as the row ``-x <= 0``: row info ``-x``, coefficients from
    ``addExpr`` and right-hand side 0.0. An empty expression adds nothing.
    """
    if x.size() != 0:
        self.Au.addrow(info = -x)
        self.addExpr(self.Au, -x)
        self.bu.append(0.0)
def addExpr_ge0_fcn(self, x):
    """Harvest functional dependencies from a constraint ``x >= 0``.

    This only happens when ``self.fcn_mode >= 1`` and ``x`` is nonpositive.
    Together with ``x >= 0``, that forces ``x`` to be zero, and its
    conditional-entropy terms are passed to ``addfcn``. An empty expression
    is ignored.
    """
    if x.size() == 0 or self.fcn_mode < 1:
        return
    if x.isnonpos():
        self.addfcn(x)
def addExpr_eq0(self, x):
    """Add the constraint ``x == 0`` as an equality row with right-hand side 0.0.

    A copy of ``x`` is stored as the row info. An empty expression adds
    nothing.
    """
    if x.size() == 0:
        return
    equalities = self.Ae
    equalities.addrow(info = x.copy())
    self.addExpr(equalities, x)
    self.be.append(0.0)
def addExpr_eq0_fcn(self, x):
    """Harvest functional dependencies from a constraint ``x == 0``.

    This only happens when ``self.fcn_mode >= 1`` and ``x`` is sign-definite
    (nonpositive or nonnegative). Its conditional-entropy terms are then
    passed to ``addfcn``. An empty expression is ignored.
    """
    if x.size() == 0 or self.fcn_mode < 1:
        return
    if x.isnonpos() or x.isnonneg():
        self.addfcn(x)
def add_ent_ineq(self):
# Add the basic Shannon-type entropy inequalities as rows of Au / bu.
# Every row is stored as -expr <= 0 (row info -expr, coefficients added with
# weight -1.0, right-hand side 0.0), i.e. it encodes expr >= 0.
# Which inequalities are generated depends on flags read from self:
#   * non-quantum with any of hge0 / hcge0 / ige0 / icge0 disabled: only the
#     enabled families I(X_A; X_B) >= 0, H(X_A) >= 0 and H(X_i | X_Z) >= 0
#     are added through addExpr_ge0, then the method returns early;
#   * quantum: H(X_i | X_Z) + H(X_i | X_W) >= 0 with W = all - {i} - Z,
#     each unordered pair {Z, W} visited once;
#   * dual_form: H(X_i | X_Z) >= 0 for every Z not containing i;
#   * otherwise: H(X_i | all other variables) >= 0.
# The pair loop at the end generates I(X_i; X_j | X_Z) >= 0 for i < j and
# every Z avoiding both i and j.
# NOTE(review): in the early-return branch the icge0 flag adds no rows of its
# own (no conditional mutual information family is generated) - confirm.
quantum = self.quantum
n = self.index.num_rv()
npow = (1 << n)
hge0 = self.hge0
hcge0 = self.hcge0
ige0 = self.ige0
icge0 = self.icge0
if not quantum and (not hge0 or not hcge0 or not ige0 or not icge0):
if ige0:
# I(X_A; X_B) >= 0 for every pair of nonempty masks with A < B
for xmask in range(1, 1 << n):
for ymask in range(xmask + 1, 1 << n):
self.addExpr_ge0(Expr.I(self.index.comprv.from_mask(xmask), self.index.comprv.from_mask(ymask)))
if hge0:
# H(X_A) >= 0 for every nonempty mask A
for xmask in range(1, 1 << n):
self.addExpr_ge0(Expr.H(self.index.comprv.from_mask(xmask)))
if hcge0:
# H(X_i | X_Z) >= 0 for every nonempty Z not containing i
for zmask in range(1, 1 << n):
for x in range(n):
if (zmask & (1 << x)) == 0:
self.addExpr_ge0(Expr.Hc(self.index.comprv[x], self.index.comprv.from_mask(zmask)))
return
if quantum:
for x in range(n):
zmask = 0
while zmask < npow:
# wmask is the complement of zmask and {x}; zmask only grows, so
# stopping once zmask > wmask visits each unordered {Z, W} once.
wmask = npow - 1 - (1 << x) - zmask
if zmask > wmask:
break
self.Au.addrow(info = -Expr.Hc(self.index.comprv[x], self.index.comprv.from_mask(zmask))
- Expr.Hc(self.index.comprv[x], self.index.comprv.from_mask(wmask)))
self.addIc_mask(self.Au, x, x, zmask, -1.0)
self.addIc_mask(self.Au, x, x, wmask, -1.0)
self.bu.append(0.0)
# Advance to the next mask without bit x: if the increment set bit x,
# adding 1 << x clears it and carries into the higher bits.
zmask += 1
if (zmask & (1 << x)) != 0:
zmask += (1 << x)
else:
if self.dual_form:
for x in range(n):
zmask = 0
# H(X_x | X_Z) >= 0 for every Z not containing x (same skip-bit enumeration)
while zmask < npow:
self.Au.addrow(info = -Expr.Hc(self.index.comprv[x], self.index.comprv.from_mask(zmask)))
self.addIc_mask(self.Au, x, x, zmask, -1.0)
self.bu.append(0.0)
zmask += 1
if (zmask & (1 << x)) != 0:
zmask += (1 << x)
else:
# H(X_x | all other variables) >= 0
for x in range(n):
self.Au.addrow(info = -Expr.Hc(self.index.comprv[x], self.index.comprv.from_mask(npow - 1 - (1 << x))))
self.addIc_mask(self.Au, x, x, npow - 1 - (1 << x), -1.0)
self.bu.append(0.0)
# I(X_x; X_y | X_Z) >= 0 for x < y and every Z avoiding both x and y
for x in range(n):
for y in range(x + 1, n):
zmask = 0
while zmask < npow:
self.Au.addrow(info = -Expr.Ic(self.index.comprv[x], self.index.comprv[y], self.index.comprv.from_mask(zmask)))
self.addIc_mask(self.Au, x, y, zmask, -1.0)
if self.lptype == LinearProgType.HC1BN or self.lptype == LinearProgType.HMIN:
# In these lptypes some cells have no column (addH_mask skips them),
# so the row may cancel out; rows that simplify to empty are dropped.
self.Au.simplify_last_row()
if self.Au.last_row_isempty():
self.Au.poprow()
else:
self.bu.append(0.0)
else:
self.bu.append(0.0)
# Next mask avoiding bits x and y (x < y): carry past x first, then past y.
zmask += 1
if (zmask & (1 << x)) != 0:
zmask += (1 << x)
if (zmask & (1 << y)) != 0:
zmask += (1 << y)
def finish(self, skip_ent_ineq = False, skip_remove_empty = False, skip_solver = False):
# Finalise the LP after all constraints have been added, then (unless
# skip_solver) build the solver model via calc_solver().
# Steps, in order:
#   1. add the basic entropy inequalities (unless skip_ent_ineq);
#   2. for HC1BN / HMIN, deduplicate inequality rows and reset bu to zeros;
#   3. locate the special real variables IVar.one / eps / inf and mark the
#      program as affine / eps-dependent;
#   4. if lp_bounded, add box constraints H(all) <= lp_ubound and
#      |r| <= lp_ubound for each real variable r;
#   5. substitute the constants (1, lp_eps, lp_ubound) of the special
#      variables into the right-hand sides and drop their columns;
#   6. drop all-zero columns (unless skip_remove_empty), recording the
#      old -> new column map in self.xvarid (-1 = dropped);
#   7. remove empty rows, setting self.pinfeas if an empty row is violated.
ceps = PsiOpts.settings["eps"]
dual_form = self.dual_form
if not skip_ent_ineq:
self.add_ent_ineq()
PsiRec.num_lpprob += 1
#self.Au.simplify()
#self.Ae.simplify()
if self.lptype == LinearProgType.HC1BN or self.lptype == LinearProgType.HMIN:
self.Au.unique_row()
# NOTE(review): bu is rebuilt as all zeros after deduplication; the row
# adders visible here (addExpr_ge0, add_ent_ineq) only use rhs 0.0.
self.bu = [0.0] * len(self.Au.x)
# Dead code: former approach pinning one / eps / inf via equality rows.
if False:
k = self.index.get_index(IVar.one())
if k >= 0:
self.Ae.addrow()
self.addreal_id(self.Ae, k, 1.0)
self.be.append(1.0)
self.affine_present = True
k = self.index.get_index(IVar.eps())
if k >= 0:
self.Ae.addrow()
self.addreal_id(self.Ae, k, 1.0)
self.be.append(self.lp_eps)
self.eps_present = True
self.affine_present = True
k = self.index.get_index(IVar.inf())
if k >= 0:
self.Ae.addrow()
self.addreal_id(self.Ae, k, 1.0)
self.be.append(self.lp_ubound)
self.affine_present = True
if True:
# Column ids (before compression) of the special real variables; -1 if absent.
id_one = self.index.get_index(IVar.one())
if id_one >= 0:
id_one += self.realshift
id_eps = self.index.get_index(IVar.eps())
if id_eps >= 0:
id_eps += self.realshift
id_inf = self.index.get_index(IVar.inf())
if id_inf >= 0:
id_inf += self.realshift
if id_one >= 0 or id_eps >= 0 or id_inf >= 0:
self.affine_present = True
if id_eps >= 0:
self.eps_present = True
if self.affine_present and not dual_form:
self.lp_bounded = True
if self.lp_bounded:
# Box constraints keep the primal LP bounded.
if self.index.num_rv() > 0:
self.Au.addrow(info = Expr.zero())
self.addH_mask(self.Au, (1 << self.index.num_rv()) - 1, 1.0)
#for i in range(self.realshift):
#    self.Au.add_last_row(i, 1.0)
self.bu.append(self.lp_ubound)
for i in range(self.index.num_real()):
if Term.fromcomp(self.index.compreal[i]).isrealvar():
self.Au.addrow(info = Expr.zero())
self.addreal_id(self.Au, i, 1.0)
self.bu.append(self.lp_ubound)
self.Au.addrow(info = Expr.zero())
self.addreal_id(self.Au, i, -1.0)
self.bu.append(self.lp_ubound)
# A column is kept if it is used by any inequality or equality row.
cols = self.Au.nonzero_cols()
coles = self.Ae.nonzero_cols()
cols = [a or b for a, b in zip(cols, coles)]
#print(self.Au.x)
if True:
# Fix the special variables to constants: move their contribution to
# the right-hand side (sumrows presumably returns, per row, the sum of
# coefficient * constant over these columns) and drop the columns.
self.constmap = {}
if id_one >= 0:
cols[id_one] = False
self.constmap[id_one] = 1.0
if id_eps >= 0:
cols[id_eps] = False
self.constmap[id_eps] = self.lp_eps
if id_inf >= 0:
cols[id_inf] = False
self.constmap[id_inf] = self.lp_ubound
if len(self.constmap):
self.bu = [b - a for a, b in zip(self.Au.sumrows(self.constmap, remove = True), self.bu)]
self.be = [b - a for a, b in zip(self.Ae.sumrows(self.constmap, remove = True), self.be)]
#print(self.Au.x)
if skip_remove_empty:
cols = [True] * len(cols)
# Compress columns: xvarid maps an original column to its new index, or -1 if dropped.
self.xvarid = [0] * self.nvar
self.nxvar = 0
for i in range(self.nvar):
if cols[i]:
self.xvarid[i] = self.nxvar
self.nxvar += 1
else:
self.xvarid[i] = -1
self.Au.mapcol(self.xvarid)
self.Au.width = self.nxvar
self.Ae.mapcol(self.xvarid)
self.Ae.width = self.nxvar
# print(self.xvarid)
# print(self.Au.x)
# print(self.bu)
#print(self.Au.x)
if True:
# An empty row reads 0 <= b (or 0 == b); if that fails the program is
# infeasible. Empty rows are then removed together with their rhs.
for i in range(len(self.Au.x)):
if len(self.Au.x[i]) == 0:
if self.bu[i] < -ceps:
self.pinfeas = True
self.bu[i] = None
for i in range(len(self.Ae.x)):
if len(self.Ae.x[i]) == 0:
if abs(self.be[i]) > ceps:
self.pinfeas = True
self.be[i] = None
# self.Au.x = [a for a in self.Au.x if len(a)]
# self.Ae.x = [a for a in self.Ae.x if len(a)]
self.Au.remove_empty_rows()
self.Ae.remove_empty_rows()
self.bu = [a for a in self.bu if a is not None]
self.be = [a for a in self.be if a is not None]
#print(self.Au.x)
if not skip_solver:
self.calc_solver()
def apply_discover_bounds(self):
    """Write the current discovery bounds into the OR-Tools model.

    Index -1 in ``self.vdiscover_disabled`` refers to the weighted-sum
    constraint ``solver_param["sum_bound_cons"]``:

    * If -1 is disabled, the constraint becomes unbounded.
    * Otherwise its lower bound is set to -1.

    The discovery multipliers ``self.vdiscover`` are handled position by
    position:

    * Disabled positions get (-inf, inf).
    * The rest get the bounds stored in ``self.vdiscover_bounds``.

    Raises:
        RuntimeError: if the active solver is not an OR-Tools solver.
    """
    if not self.solver.startswith("ortools."):
        raise RuntimeError("Requires ortools solver")
    model = self.solver_param["model"]
    xvar = self.solver_param["xvar"]
    sum_cons = self.solver_param["sum_bound_cons"]
    disabled = self.vdiscover_disabled
    if sum_cons is not None:
        lower = -model.infinity() if -1 in disabled else -1
        sum_cons.SetBounds(lower, model.infinity())
    for pos, (col, bound) in enumerate(zip(self.vdiscover, self.vdiscover_bounds)):
        var = xvar[col]
        if pos in disabled:
            var.SetBounds(-model.infinity(), model.infinity())
        else:
            var.SetBounds(bound[0], bound[1])
def active_discover_bounds(self, v):
    """Return the set of discovery bounds that are (nearly) tight at solution ``v``.

    A bound counts as active once ``v`` reaches 95% of it. The result may
    contain:

    * -1, when the weighted sum (weights ``solver_param["sum_bound_weights"]``)
      is <= -0.95 and index -1 is not disabled;
    * the position i of each enabled discovery multiplier whose value is
      within 95% of its finite numeric lower or upper bound.

    An empty set is returned when ``v`` is None.
    """
    active = set()
    if v is None:
        return active
    tol_coeff = 0.95
    disabled = self.vdiscover_disabled
    if -1 not in disabled:
        weights = self.solver_param["sum_bound_weights"]
        weighted = sum(v[col] * w for col, w in zip(self.vdiscover, weights))
        if weighted <= -tol_coeff:
            active.add(-1)
    numeric = (float, int)
    for pos, (col, bound) in enumerate(zip(self.vdiscover, self.vdiscover_bounds)):
        if pos in disabled:
            continue
        if isinstance(bound[0], numeric) and v[col] <= bound[0] * tol_coeff:
            active.add(pos)
        elif isinstance(bound[1], numeric) and v[col] >= bound[1] * tol_coeff:
            active.add(pos)
    return active
def calc_solver(self):
# Build the backend-specific solver model from Au/bu (rows Au x <= bu) and
# Ae/be (rows Ae x == be), storing it in self.solver_param:
#   * "scipy":    LIL sparse copies of Au / Ae ("Aus", "Aes");
#   * "pulp.*":   a PuLP LpProblem ("prob") and variable dict ("xvar");
#   * "pyomo.*":  a Pyomo SolverFactory ("opt") and ConcreteModel ("model");
#   * "ortools.*": a pywraplp Solver ("model") and variable list ("xvar").
# With self.dual_form set (and dual_form_obj given) the model is the dual:
# one nonnegative multiplier per inequality row (sign -1) and, for equality
# rows, two multipliers with signs -1 / +1 (or a single one in discover
# mode). The combination of rows must equal dual_form_obj column by column,
# and the rhs combination must be >= -dual_form_cutoff. A column that no row
# touches but that has a nonzero target sets self.dual_form_infeas.
# dual_form_weights (1 + 0.2*i/ncons + 0.2*#components of the row info,
# times optional "dual_weight" metadata) are only recorded here; no
# objective is set in this method.
verbose_discover = PsiOpts.settings.get("verbose_discover", False)
ceps = PsiOpts.settings["eps"]
dual_form = self.dual_form
self.dual_form_infeas = False
if self.solver == "scipy":
self.solver_param["Aus"] = self.Au.tolil()
self.solver_param["Aes"] = self.Ae.tolil()
elif self.solver.startswith("pulp."):
prob = None
xvar = None
if dual_form:
if self.dual_form_obj is not None:
self.dual_form_ncons = len(self.Au.x) + len(self.Ae.x) * 2
self.dual_form_weights = []
prob = pulp.LpProblem("lpentineq" + str(PsiRec.num_lpprob), pulp.LpMinimize)
xvar = pulp.LpVariable.dicts("x", [str(i) for i in range(self.dual_form_ncons)])
# vexprs[j] accumulates sum_i multiplier_i * sign_i * row_i[j];
# vexprs[nxvar] holds the right-hand-side combination.
vexprs = [None] * (self.nxvar + 1)
for i, (cx, cb, csn, cri) in enumerate(itertools.chain(((tx, tb, -1, tri) for tx, tb, tri in zip(self.Au.x, self.bu, self.Au.rowinfo)),
((tx, tb, tsn, tri) for tx, tb, tri in zip(self.Ae.x, self.be, self.Ae.rowinfo) for tsn in [-1, 1]))):
prob += xvar[str(i)] >= 0
cweight = 1.0 + i * 0.2 / self.dual_form_ncons
if cri is not None:
cweight += 0.2 * len(cri.allcomp())
if cri is not None and cri.get_meta("dual_weight") is not None:
cweight *= cri.get_meta("dual_weight")
self.dual_form_weights.append(cweight)
for (j, c) in cx:
# print(str(i) + " " + str(j) + " " + str(c))
if vexprs[j] is None:
vexprs[j] = c * csn * xvar[str(i)]
else:
vexprs[j] += c * csn * xvar[str(i)]
if abs(cb) > ceps:
if vexprs[self.nxvar] is None:
vexprs[self.nxvar] = cb * csn * xvar[str(i)]
else:
vexprs[self.nxvar] += cb * csn * xvar[str(i)]
for j in range(self.nxvar):
if vexprs[j] is None:
if abs(self.dual_form_obj[j]) > ceps:
self.dual_form_infeas = True
continue
prob += vexprs[j] == self.dual_form_obj[j]
if vexprs[self.nxvar] is None:
if -self.dual_form_cutoff > ceps:
self.dual_form_infeas = True
else:
# print(vexprs[self.nxvar])
# print(self.dual_form_cutoff)
prob += vexprs[self.nxvar] >= -self.dual_form_cutoff
else:
# Primal form: free variables, one PuLP constraint per nonempty row
# (sense -1 is <=, sense 0 is ==).
prob = pulp.LpProblem("lpentineq" + str(PsiRec.num_lpprob), pulp.LpMinimize)
xvar = pulp.LpVariable.dicts("x", [str(i) for i in range(self.nxvar)])
for a, b in zip(self.Au.x, self.bu):
if len(a):
prob += pulp.LpConstraint(pulp.lpSum([xvar[str(j)] * c for (j, c) in a]),
sense = -1, rhs = b)
for a, b in zip(self.Ae.x, self.be):
if len(a):
prob += pulp.LpConstraint(pulp.lpSum([xvar[str(j)] * c for (j, c) in a]),
sense = 0, rhs = b)
# Dead code: earlier ways of building the same primal constraints.
if False:
for a, b in zip(self.Au.x, self.bu):
if len(a):
#print(" $ ".join([str((j, c)) for (j, c) in a]))
prob += pulp.LpConstraint(pulp.LpAffineExpression([(xvar[str(j)], c) for (j, c) in a]),
sense = -1, rhs = b)
for a, b in zip(self.Ae.x, self.be):
if len(a):
prob += pulp.LpConstraint(pulp.LpAffineExpression([(xvar[str(j)], c) for (j, c) in a]),
sense = 0, rhs = b)
if False:
for i in range(len(self.Au.x)):
cexpr = None
for (j, c) in self.Au.x[i]:
if cexpr is None:
cexpr = c * xvar[str(j)]
else:
cexpr += c * xvar[str(j)]
if cexpr is not None:
prob += cexpr <= self.bu[i]
for i in range(len(self.Ae.x)):
cexpr = None
for (j, c) in self.Ae.x[i]:
if cexpr is None:
cexpr = c * xvar[str(j)]
else:
cexpr += c * xvar[str(j)]
if cexpr is not None:
prob += cexpr == self.be[i]
#print(prob)
self.solver_param["prob"] = prob
self.solver_param["xvar"] = xvar
elif self.solver.startswith("pyomo."):
# Same construction as the PuLP branch, using a Pyomo ConcreteModel
# (variables are 1-based: model.x[i + 1]).
solver_opt = self.solver[self.solver.index(".") + 1 :]
opt = SolverFactory(solver_opt)
model = pyo.ConcreteModel()
if self.dual_enabled:
model.dual = pyo.Suffix(direction=pyo.Suffix.IMPORT)
if dual_form:
if self.dual_form_obj is not None:
self.dual_form_ncons = len(self.Au.x) + len(self.Ae.x) * 2
self.dual_form_weights = []
model.n = pyo.Param(default = self.dual_form_ncons)
model.x = pyo.Var(pyo.RangeSet(model.n), domain=pyo.Reals)
model.c = pyo.ConstraintList()
vexprs = [None] * (self.nxvar + 1)
for i, (cx, cb, csn, cri) in enumerate(itertools.chain(((tx, tb, -1, tri) for tx, tb, tri in zip(self.Au.x, self.bu, self.Au.rowinfo)),
((tx, tb, tsn, tri) for tx, tb, tri in zip(self.Ae.x, self.be, self.Ae.rowinfo) for tsn in [-1, 1]))):
model.c.add(model.x[i + 1] >= 0)
cweight = 1.0 + i * 0.2 / self.dual_form_ncons
if cri is not None:
cweight += 0.2 * len(cri.allcomp())
if cri is not None and cri.get_meta("dual_weight") is not None:
cweight *= cri.get_meta("dual_weight")
self.dual_form_weights.append(cweight)
for (j, c) in cx:
# print(str(i) + " " + str(j) + " " + str(c))
if vexprs[j] is None:
vexprs[j] = c * csn * model.x[i + 1]
else:
vexprs[j] += c * csn * model.x[i + 1]
if abs(cb) > ceps:
if vexprs[self.nxvar] is None:
vexprs[self.nxvar] = cb * csn * model.x[i + 1]
else:
vexprs[self.nxvar] += cb * csn * model.x[i + 1]
for j in range(self.nxvar):
if vexprs[j] is None:
if abs(self.dual_form_obj[j]) > ceps:
self.dual_form_infeas = True
continue
model.c.add(vexprs[j] == self.dual_form_obj[j])
if vexprs[self.nxvar] is None:
if -self.dual_form_cutoff > ceps:
self.dual_form_infeas = True
else:
# print(vexprs[self.nxvar])
# print(self.dual_form_cutoff)
model.c.add(vexprs[self.nxvar] >= -self.dual_form_cutoff)
else:
model.n = pyo.Param(default=self.nxvar)
model.x = pyo.Var(pyo.RangeSet(model.n), domain=pyo.Reals)
model.c = pyo.ConstraintList()
for i in range(len(self.Au.x)):
cexpr = None
for (j, c) in self.Au.x[i]:
if cexpr is None:
cexpr = c * model.x[j + 1]
else:
cexpr += c * model.x[j + 1]
if cexpr is not None:
model.c.add(cexpr <= self.bu[i])
for i in range(len(self.Ae.x)):
cexpr = None
for (j, c) in self.Ae.x[i]:
if cexpr is None:
cexpr = c * model.x[j + 1]
else:
cexpr += c * model.x[j + 1]
if cexpr is not None:
model.c.add(cexpr == self.be[i])
self.solver_param["opt"] = opt
self.solver_param["model"] = model
elif self.solver.startswith("ortools."):
solver_opt = self.solver[self.solver.index(".") + 1 :]
model = ortools.linear_solver.pywraplp.Solver.CreateSolver(solver_opt)
xvar = None
if not model:
raise RuntimeError("Fail to initialize solver " + self.solver)
# if self.dual_enabled:
#     model.dual = pyo.Suffix(direction=pyo.Suffix.IMPORT)
if dual_form:
if self.dual_form_obj is not None:
# "discover" mode: instead of matching a given objective, search
# for a combination of rows and of self.dual_discover_exprs that
# sums to zero; the discover multipliers are tracked in
# self.vdiscover and bounded through apply_discover_bounds().
is_dual_discover = (isinstance(self.dual_form_obj, str) and self.dual_form_obj == "discover")
if is_dual_discover:
n_dual_discover_exprs = len(self.dual_discover_exprs)
if self.dual_discover_l1:
n_dual_discover_exprs *= 2
self.dual_form_ncons = len(self.Au.x) + len(self.Ae.x) + n_dual_discover_exprs
self.dual_form_weights = []
n = self.dual_form_ncons
# NOTE(review): in L1 mode every multiplier is nonnegative,
# including those of equality rows, which appear only once
# (Aesigns = [-1]); equalities can then be used in one
# direction only - confirm this is intended.
if self.dual_discover_l1:
xvar = [model.NumVar(0, model.infinity(), 'x' + str(i)) for i in range(n)]
else:
xvar = [model.NumVar(0, model.infinity(), 'x' + str(i)) if i < len(self.Au.x)
else model.NumVar(-model.infinity(), model.infinity(), 'x' + str(i)) for i in range(n)]
Aesigns = [-1]
else:
self.dual_form_ncons = len(self.Au.x) + len(self.Ae.x) * 2
self.dual_form_weights = []
n = self.dual_form_ncons
xvar = [model.NumVar(0, model.infinity(), 'x' + str(i)) for i in range(n)]
Aesigns = [-1, 1]
# model.c = pyo.ConstraintList()
vexprs = [None] * (self.nxvar + 1)
self.vdiscover = []
self.vdiscover_bounds = []
dual_discover_list = []
if is_dual_discover:
# Each discover expression becomes an extra "row" tagged with
# cri = False (and its negation too in L1 mode).
for expr in self.dual_discover_exprs:
t, t2 = self.get_vec(expr, sparse = True)
dual_discover_list.append((t.x[0], t2, 1, False))
if self.dual_discover_l1:
dual_discover_list.append(([(j, -c) for j, c in t.x[0]], -t2, 1, False))
# print("T " + str(t.x[0]))
for i, (cx, cb, csn, cri) in enumerate(itertools.chain(((tx, tb, -1, tri) for tx, tb, tri in zip(self.Au.x, self.bu, self.Au.rowinfo)),
((tx, tb, tsn, tri) for tx, tb, tri in zip(self.Ae.x, self.be, self.Ae.rowinfo) for tsn in Aesigns),
dual_discover_list)):
# model.Add(x[i] >= 0)
cweight = 1.0 + i * 0.2 / self.dual_form_ncons
if cri is not None and cri is not False:
cweight += 0.2 * len(cri.allcomp())
if cri is not None and cri is not False and cri.get_meta("dual_weight") is not None:
cweight *= cri.get_meta("dual_weight")
self.dual_form_weights.append(cweight)
for (j, c) in cx:
# print(str(i) + " " + str(j) + " " + str(c))
if vexprs[j] is None:
vexprs[j] = c * csn * xvar[i]
else:
vexprs[j] += c * csn * xvar[i]
if abs(cb) > ceps:
if vexprs[self.nxvar] is None:
vexprs[self.nxvar] = cb * csn * xvar[i]
else:
vexprs[self.nxvar] += cb * csn * xvar[i]
if cri is False:
# discover multiplier: remember its variable index, initially unbounded
self.vdiscover.append(i)
self.vdiscover_bounds.append([-model.infinity(), model.infinity()])
# discover_i = cri.get_meta("dual_discover")
# if discover_i is not None:
#     while len(self.vdiscover) <= discover_i:
#         self.vdiscover.append(-1)
#     self.vdiscover[discover_i] = i
if is_dual_discover:
# Every column (including the rhs column nxvar) must cancel.
for j in range(self.nxvar + 1):
if vexprs[j] is not None:
model.Add(vexprs[j] == 0)
if self.vdiscover:
sum_bound = None
default_bound = 10000
sum_bound_weights = []
if self.dual_discover_l1:
# NOTE(review): this branch builds sum_bound but leaves
# sum_bound_weights empty, so sum_bound_cons below gets
# no coefficients in L1 mode - confirm intended.
for x2 in self.vdiscover:
if sum_bound is None:
sum_bound = xvar[x2] * -1
else:
sum_bound += xvar[x2] * -1
else:
# Weight per discover expression: taken from
# dual_discover_init_model if given, else expr.rv_weight().
# The sign of the weight decides which side of the
# multiplier gets the +/- default_bound box.
for i2, (x2, expr) in enumerate(zip(self.vdiscover, self.dual_discover_exprs)):
cw = 0.0
if self.dual_discover_init_model is not None:
cw = float(self.dual_discover_init_model[expr])
else:
cw = expr.rv_weight()
if verbose_discover:
print("BOUND " + str(expr) + " : " + str(cw))
# print(expr)
# print(cw)
# print()
sum_bound_weights.append(cw)
if cw != 0:
if sum_bound is None:
sum_bound = xvar[x2] * cw
else:
sum_bound += xvar[x2] * cw
if cw >= 0:
self.vdiscover_bounds[i2][1] = default_bound
# model.Add(xvar[x2] <= default_bound)
if cw <= 0:
self.vdiscover_bounds[i2][0] = -default_bound
# model.Add(xvar[x2] >= -default_bound)
# if cw <= 0:
#     model.Add(xvar[x2] >= -default_bound)
#     model.Add(xvar[x2] <= default_bound)
# else:
#     # model.Add(xvar[x2] >= -default_bound)
#     model.Add(xvar[x2] <= default_bound)
# if sum_bound is None:
#     sum_bound = xvar[x2] * cw
# else:
#     sum_bound += xvar[x2] * cw
# cw = 4.0 ** len(expr.allcomprv())
# model.Add(xvar[x2] >= -cw)
# model.Add(xvar[x2] <= cw)
self.solver_param["model"] = model
self.solver_param["xvar"] = xvar
self.solver_param["sum_bound_weights"] = sum_bound_weights
self.solver_param["sum_bound_cons"] = None
if sum_bound is not None:
# sum_bound_cons = model.Add(sum_bound >= -1)
# self.solver_param["sum_bound_cons"] = sum_bound_cons
# Created with lower bound -100; apply_discover_bounds() then
# resets it (to -1, or unbounded if index -1 is disabled).
sum_bound_cons = model.Constraint(-100, model.infinity())
for x2, cw in zip(self.vdiscover, sum_bound_weights):
if cw != 0:
sum_bound_cons.SetCoefficient(xvar[x2], cw)
self.solver_param["sum_bound_cons"] = sum_bound_cons
self.apply_discover_bounds()
# model.Add(sum(xvar[x2] for x2 in self.vdiscover) <= 1)
# model.Add(sum(xvar[x2] for x2 in self.vdiscover) >= -1)
# for x2 in self.vdiscover:
#     model.Add(xvar[x2] >= -1)
# model.Add(sum(xvar[x2] for x2 in self.vdiscover) <= 1)
else:
for j in range(self.nxvar):
if vexprs[j] is None:
if abs(self.dual_form_obj[j]) > ceps:
self.dual_form_infeas = True
continue
model.Add(vexprs[j] == self.dual_form_obj[j])
if vexprs[self.nxvar] is None:
if -self.dual_form_cutoff > ceps:
self.dual_form_infeas = True
else:
# print(vexprs[self.nxvar])
# print(self.dual_form_cutoff)
model.Add(vexprs[self.nxvar] >= -self.dual_form_cutoff)
else:
# Primal form: free variables, one linear constraint per nonempty row.
# model.n = pyo.Param(default=self.nxvar)
# model.x = pyo.Var(pyo.RangeSet(model.n), domain=pyo.Reals)
# model.c = pyo.ConstraintList()
n = self.nxvar
xvar = [model.NumVar(-model.infinity(), model.infinity(), 'x' + str(i)) for i in range(n)]
for i in range(len(self.Au.x)):
cexpr = None
for (j, c) in self.Au.x[i]:
if cexpr is None:
cexpr = c * xvar[j]
else:
cexpr += c * xvar[j]
if cexpr is not None:
model.Add(cexpr <= self.bu[i])
for i in range(len(self.Ae.x)):
cexpr = None
for (j, c) in self.Ae.x[i]:
if cexpr is None:
cexpr = c * xvar[j]
else:
cexpr += c * xvar[j]
if cexpr is not None:
model.Add(cexpr == self.be[i])
self.solver_param["model"] = model
self.solver_param["xvar"] = xvar
def id_toexpr(self):
# Return a function rf(j2) mapping a compressed LP column index to
# (Expr, mask):
#   * j2 == nxvar -> (constant 1, 0), the extra rhs column used by callers
#     such as get_dual_terms; j2 > nxvar -> None;
#   * a real-variable column -> (Expr.real(name), 0);
#   * an entropy column -> the entropy it represents, with its variable mask:
#     H(X_mask) for lptype H / HMIN, and for HC1BN the conditional entropy
#     H(X_top | X_rest), where X_top is the highest-indexed variable in mask.
n = self.index.num_rv()
# xvarinv: compressed column -> original column (inverse of self.xvarid)
xvarinv = [0] * self.nxvar
for i in range(self.nvar):
if self.xvarid[i] >= 0:
xvarinv[self.xvarid[i]] = i
# cellposinv: original entropy column -> mask. For lptype H, column j is mask j + 1.
cellposinv = list(range(1, self.realshift + 1))
if self.lptype == LinearProgType.HC1BN or self.lptype == LinearProgType.HMIN:
# Masks are scanned in descending order, so when several masks share a
# cell the smallest mask is written last and wins.
for mask in range((1 << n) - 1, 0, -1):
if self.cellpos[mask] >= 0:
cellposinv[self.cellpos[mask]] = mask
def rf(j2):
if j2 >= self.nxvar:
if j2 == self.nxvar:
return (Expr.const(1.0), 0)
return None
j = xvarinv[j2]
if j >= self.realshift:
return (Expr.real(self.index.compreal.varlist[j - self.realshift].name), 0)
mask = cellposinv[j]
term = None
if self.lptype == LinearProgType.H or self.lptype == LinearProgType.HMIN:
term = Term.H(Comp.empty())
for i in range(n):
if mask & (1 << i) != 0:
term.x[0].varlist.append(self.index.comprv.varlist[i])
elif self.lptype == LinearProgType.HC1BN:
term = Term.Hc(Comp.empty(), Comp.empty())
for i in range(n):
if mask & (1 << i) != 0:
term.z.varlist.append(self.index.comprv.varlist[i])
# move the highest-indexed variable from the condition to the argument
term.x[0].varlist.append(term.z.varlist.pop())
return (Expr.fromterm(term), mask)
return rf
def get_var_exprs(self):
    """Return an ExprArray with the expression represented by each LP column.

    The columns are the ``self.nxvar`` columns left after compression, in
    order.
    """
    to_expr = self.id_toexpr()
    return ExprArray([to_expr(col)[0] for col in range(self.nxvar)])
def row_toexpr(self):
    """Return a function that converts an LP row into an Expr.

    The returned function accepts any of:

    * a SparseMat, whose first row is used;
    * a sparse row, i.e. a sequence of ``(column, coefficient)`` tuples;
    * a dense sequence of coefficients indexed by column. Here entries with
      absolute value <= ``PsiOpts.settings["eps"]`` are skipped.

    Each column is mapped through ``id_toexpr`` and the results are summed.
    """
    idt = self.id_toexpr()

    def rf(x):
        if isinstance(x, SparseMat):
            x = x.x[0]
        result = Expr.zero()
        if len(x) == 0:
            return result
        if isinstance(x[0], tuple):
            # sparse: (column, coefficient) pairs
            for col, coeff in x:
                term_expr, _mask = idt(col)
                result += term_expr * coeff
            return result
        # dense: coefficient at each column index
        ceps = PsiOpts.settings["eps"]
        for col, coeff in enumerate(x):
            if abs(coeff) > ceps:
                term_expr, _mask = idt(col)
                result += term_expr * coeff
        return result

    return rf
def get_region(self, toreal = None, toreal_only = False, A = None, skip_simplify = False):
    """Convert LP rows back into a Region.

    Args:
        toreal: components whose entropy terms are replaced by real variables
            named ``"R_" + str(term)`` (default: none).
        toreal_only: keep only rows that contain at least one replaced term.
        A: optional SparseMat. If given, its rows are read as ``row <= 0``.
            Otherwise all of Au/bu (as <=) and Ae/be (as ==) are used.
        skip_simplify: do not call ``simplify_quick`` on each row expression.
    """
    if toreal is None:
        toreal = Comp.empty()
    torealmask = self.index.get_mask(toreal)
    idt = self.id_toexpr()
    if A is None:
        rows = [(a, 1, b) for a, b in zip(self.Au.x, self.bu)]
        rows += [(a, 0, b) for a, b in zip(self.Ae.x, self.be)]
    else:
        rows = [(a, 1, 0.0) for a in A.x]
    region = Region.universe()
    for row, is_ineq, rhs in rows:
        expr = Expr.zero()
        has_real = False
        for col, coeff in row:
            term_expr, mask = idt(col)
            if mask & torealmask:
                has_real = True
                expr += Expr.real("R_" + str(term_expr)) * coeff
            else:
                expr += term_expr * coeff
        if toreal_only and not has_real:
            continue
        if not skip_simplify:
            expr.simplify_quick()
        if is_ineq == 1:
            region &= (expr <= Expr.const(rhs))
        else:
            region &= (expr == Expr.const(rhs))
    return region
def get_dual_region(self, mul_coeff = True, omit_trivial = False, tosum = None):
    """Return the rows used by the dual solution as a Region, or None if there is no dual.

    Every row whose dual value ``d`` has ``abs(d) > eps`` contributes
    ``rhs - row``:

    * scaled by ``-d`` when ``mul_coeff`` is set;
    * to ``exprs_ge`` for an inequality row, or to ``exprs_eq`` for an
      equality row.

    With ``omit_trivial``, inequality rows whose expression is already
    nonnegative are skipped. When ``tosum`` is given, ``x2 * -d`` is added to
    it with ``+=``. Whether the caller sees that update depends on Expr
    supporting in-place addition.
    """
    if self.dual_e is None or self.dual_u is None:
        return None
    rowt = self.row_toexpr()
    ceps = PsiOpts.settings["eps"]
    region = Region.universe()
    for row, rhs, d in zip(self.Au.x, self.bu, self.dual_u):
        if not abs(d) > ceps:
            continue
        x2 = (rhs - rowt(row)).simplified_quick()
        if omit_trivial and x2.isnonneg():
            continue
        region.exprs_ge.append(x2 * -d if mul_coeff else x2)
        if tosum is not None:
            tosum += x2 * -d
    for row, rhs, d in zip(self.Ae.x, self.be, self.dual_e):
        if not abs(d) > ceps:
            continue
        x2 = (rhs - rowt(row)).simplified_quick()
        region.exprs_eq.append(x2 * -d if mul_coeff else x2)
        if tosum is not None:
            tosum += x2 * -d
    return region
def get_dual_sum_region(self, x):
    """Return the dual region of this program, extended so it accounts for ``x``.

    Starts from ``get_dual_region()``. The region's expressions are summed
    (with proof recording disabled) and the remainder ``x - sum`` is
    computed. For each part ``t`` obtained by iterating that remainder,
    ``t`` is compared with its LP representation
    ``rowt(self.get_vec(t)[0])``. Any nonzero difference is appended as an
    extra equality.

    Returns:
        The Region, or None when no dual solution is stored, i.e. when
        ``get_dual_region()`` returns None.
    """
    r = self.get_dual_region()
    if r is None:
        # No dual solution recorded; same convention as get_dual_region.
        return None
    with PsiOpts(proof_enabled = False):
        rsum = sum(r.exprs_ge + r.exprs_eq, Expr.zero())
    rdiff = (x - rsum).simplified_quick()
    if not rdiff.iszero():
        rowt = self.row_toexpr()
        for t in rdiff:
            v = self.get_vec(t)[0]
            if v is None:
                continue
            tv = (t - rowt(v)).simplified_quick()
            if not tv.iszero():
                r.exprs_eq.append(tv)
    return r
def get_dual_sum_meta(self, rowt = None, meta = "sim"):
    """Sum the row-info expressions tagged with ``meta``, weighted by their dual values.

    For every inequality or equality row:

    * its dual value ``d`` must satisfy ``abs(d) > eps``;
    * its row info must exist and have metadata ``meta`` set.

    Each such row adds ``row_info * d`` to the result. ``d`` is first snapped
    with ``iutil.float_snap`` using ``PsiOpts.settings["max_denom_lp"]``.

    Args:
        rowt: accepted for backward compatibility; not used.
        meta: metadata key that selects the rows.

    Returns:
        The summed Expr, or None when no dual solution is stored.
    """
    if self.dual_e is None or self.dual_u is None:
        return None
    ceps = PsiOpts.settings["eps"]
    r = Expr.zero()
    # x and the rhs are kept in the zip so the iteration length matches the row count.
    rows = itertools.chain(zip(self.Au.x, self.Au.rowinfo, self.bu, self.dual_u),
                           zip(self.Ae.x, self.Ae.rowinfo, self.be, self.dual_e))
    for _x, ri, _xb, d in rows:
        if not abs(d) > ceps:
            continue
        # rows may be added without info (addrow() with no argument)
        if ri is None or ri.get_meta(meta) is None:
            continue
        r += ri * iutil.float_snap(d, PsiOpts.settings["max_denom_lp"], force = True)
    return r
def copy_empty_nobnet(self):
    """Return an empty finished LinearProg with the same index and lptype but no network structure.

    The copy uses a fresh ``BayesNet()``. It is finished without entropy
    inequalities, without removing empty columns and without building a
    solver. This only applies to the HC1BN and HMIN lptypes; for any other
    lptype None is returned.
    """
    if self.lptype not in (LinearProgType.HC1BN, LinearProgType.HMIN):
        return None
    empty_prog = LinearProg(self.index.copy(), self.lptype, BayesNet())
    empty_prog.finish(skip_ent_ineq = True, skip_remove_empty = True, skip_solver = True)
    return empty_prog
def col_steps(self, cri, rowt):
    """Return a chain of expressions leading from ``cri`` to its representation in this LP.

    ``reduced`` is ``cri`` mapped through this program, i.e.
    ``rowt(self.get_vec(cri))``. The result is one of:

    * ``[cri]`` when ``cri`` and ``reduced`` are the same expression;
    * ``[cri, mid, reduced]`` when both are single conditional entropies
      H(X | Z1) and H(X | Z2) with the same X, neither Z contains the other,
      and ``mid = H(X | Z1 + Z2)`` has the same LP vector as ``reduced``;
    * ``[cri, reduced]`` in every other case.
    """
    reduced = rowt(self.get_vec(cri, sparse = True, cat = True))
    if (cri - reduced).simplified_quick().iszero():
        return [cri]
    same_single_hc = (len(cri) == 1 and cri.terms[0][0].ishc()
                      and len(reduced) == 1 and reduced.terms[0][0].ishc()
                      and cri.terms[0][0].x == reduced.terms[0][0].x)
    if not same_single_hc:
        return [cri, reduced]
    first = cri.terms[0][0]
    second = reduced.terms[0][0]
    if first.z.super_of(second.z) or second.z.super_of(first.z):
        return [cri, reduced]
    mid = Expr.Hc(first.x[0], first.z + second.z)
    if self.get_vec(reduced - mid, sparse = True, cat = True).iszero():
        return [cri, mid, reduced]
    return [cri, reduced]
def get_dual_terms(self, expr, rowt = None, alt_prog = None, ri_simplify = True, nshuffle = 3, deficit_separate = True):
# Turn the stored dual solution into a list of proof terms
# (vector * d, row_expr * d, iseq, curtype), or None if there is no dual.
#   * Rows with abs(dual) > eps are collected. d is snapped to a fraction;
#     each row becomes a dense vector whose last entry holds -rhs.
#   * With alt_prog, vectors are expressed in alt_prog's coordinates. The
#     mismatch ("deficit") between both representations is accumulated and,
#     where possible, explained by col_steps chains; any residual is
#     appended as a single term.
#   * curtype classifies each term: pf_note metadata, "#TRI" (trivially
#     signed), "#CTRI" (when self.noskip), or deficit types "#TRI" / "#DTRI".
#   * Then nshuffle greedy passes merge pairs with equal |d| and compatible
#     types whenever the merged expression is not longer; the shortest
#     result is kept.
#   * Finally a single IC3 term (as reported by isic3()) is split into two Ic
#     terms when one of them is zero in this LP or by the Bayesian network.
verbose = PsiOpts.settings.get("verbose_proof_step", False)
if self.dual_e is None or self.dual_u is None:
return None
ceps = PsiOpts.settings["eps"]
# bnet is never reassigned below, so the merge step simplifies with bnet = None.
bnet = None
if rowt is None:
rowt = self.row_toexpr()
alt_rowt = None
deficit = None
if alt_prog is not None:
alt_rowt = alt_prog.row_toexpr()
deficit_type = "#TRI"
if deficit_separate:
deficit_type = "#DTRI"
r = []
tmpi = 0
for tmpi, (x, ri, xb, d, iseq) in enumerate(itertools.chain(zip(self.Au.x, self.Au.rowinfo, self.bu, self.dual_u, [False] * len(self.bu)),
zip(self.Ae.x, self.Ae.rowinfo, self.be, self.dual_e, [True] * len(self.be)))):
if abs(d) > ceps:
# r.append((SparseMat.from_row(x, self.nxvar).tonumpyarray_row(),
#     rowt(x).simplified_break_hc(bnet = bnet), d))
crow = SparseMat.from_row(x, self.nxvar + 1).tonumpyarray_row()
crow[self.nxvar] = -xb
# ri = ri.simplified_break_hc(bnet = bnet)
if ri_simplify:
ri = ri.simplified_break_hc(bnet = self.bnet)
d = iutil.float_snap(d, PsiOpts.settings["max_denom_lp"], force = True)
# An inequality with zero rhs whose weighted row info is nonpositive must hold with equality.
iseq = iseq or (abs(xb) < ceps and (ri * d).isnonpos())
if alt_prog is not None:
# crow2: row info in alt_prog coordinates; crow3: this LP's row
# mapped back to an Expr, then into alt_prog coordinates.
crow2 = numpy.array(alt_prog.get_vec(ri, cat = True))
ri3 = rowt(crow)
crow3 = numpy.array(alt_prog.get_vec(ri3, cat = True))
# print(crow2)
# print(crow3)
# print(ri)
# print(ri3)
# r.append((crow3 - crow2, (ri3 - ri).simplified(), d, True))
crow = crow2
if deficit is None:
deficit = (crow3 - crow2) * d
else:
deficit += (crow3 - crow2) * d
if verbose:
arow = alt_rowt(crow3 - crow2)
if not arow.iszero():
print("Deficit added : " + str(ri) + " " + str(ri3) + " : " + str(arow))
curtype = ""
rii = ri.get_meta("pf_note")
if isinstance(rii, list):
if len(rii):
curtype = str(rii[0])
elif (ri * d).isnonneg():
curtype = "#TRI"
elif (ri * d).isnonpos():
if self.noskip:
# curtype = "#N" + str(tmpi)
curtype = "#CTRI"
else:
curtype = "#TRI"
if any(abs(x) > ceps for x in crow):
r.append((crow, ri, d, iseq, curtype))
# print(str(r[-1][1]) + " * " + str(d) + " " + str(r[-1][1].get_meta("pf_note"))) # ???????????????
if alt_prog is not None:
# Also account for the target expr itself being represented differently.
exprrow = numpy.array(alt_prog.get_vec(expr, cat = True))
expr2 = rowt(self.get_vec(expr, cat = True))
exprrow2 = numpy.array(alt_prog.get_vec(expr2, cat = True))
if deficit is None:
deficit = exprrow - exprrow2
else:
deficit += exprrow - exprrow2
if deficit is not None and any(abs(x) > ceps for x in deficit):
# Try to explain each nonzero deficit column by a chain of equal expressions.
# for i in range(len(deficit) - 1, -1, -1):
for i in range(len(deficit)):
x = deficit[i]
if abs(x) <= ceps:
continue
cri = alt_rowt([(i, 1.0)])
steps = self.col_steps(cri, rowt)
if len(steps) <= 1:
if verbose:
print("Deficit UNRESOLVED: " + str(cri)) #+ " : " + str(alt_rowt(deficit)))
print("    remainder: " + str(alt_rowt(deficit)))
continue
stepsrow = [numpy.array(alt_prog.get_vec(s, sparse = False, cat = True)) for s in steps]
for j in range(len(steps) - 1):
r.append((stepsrow[j] - stepsrow[j + 1], (steps[j] - steps[j + 1]).simplified_quick(), x, True, deficit_type))
if verbose:
print("Deficit resolved " + str(j) + ": " + str(steps[j]) + " " + str(steps[j + 1]) + " : " + str((steps[j] - steps[j + 1]).simplified_quick()))
deficit += (stepsrow[-1] - stepsrow[0]) * x
if verbose:
print("    remainder: " + str(alt_rowt(deficit)))
# cri = alt_rowt([(i, x)])
# cri2 = rowt(self.get_vec(cri, sparse = True, cat = True))
# crow2 = alt_prog.get_vec(cri2, sparse = True, cat = True)
# if len(crow2.x[0]) == 1 and crow2.x[0][0][0] == i:
#     if verbose:
#         print("Deficit UNRESOLVED: " + str(i) + " " + str(cri) + " " + str(cri2) #+ " : " + str(alt_rowt(deficit)))
#     continue
# crow2 = crow2.tonumpyarray_row()
# crow2[i] -= x
# r.append((crow2, -(cri - cri2).simplified(), -1.0, True, "#TRI"))
# deficit += crow2
# if verbose:
#     print("Deficit resolved : " + str(i) + " " + str(cri) + " " + str(cri2) #+ " : " + str(alt_rowt(deficit)))
if deficit is not None and any(abs(x) > ceps for x in deficit):
r.append((-deficit, -alt_rowt(deficit).simplified_quick(), -1.0, True, deficit_type))
if verbose:
print("Deficit RESIDUAL : " + str(alt_rowt(deficit).simplified_quick()))
if alt_prog is not None:
rowt = alt_rowt
# Greedy merging: the first pass keeps the original order, later passes use
# a seeded shuffle (deterministic); the shortest resulting list is kept.
ro = r
ropt = None
rnd = random.Random(123)
for it in range(nshuffle):
r = iutil.copy(ro)
if it:
rnd.shuffle(r)
did = True
while did:
did = False
for i in range(len(r) - 1, -1, -1):
for j in range(i):
if abs(abs(r[i][2]) - abs(r[j][2])) > ceps:
continue
isn = 1
if r[i][2] * r[j][2] < 0:
isn = -1
# "#CTRI" may only merge with "#TRI"; otherwise both types must be equal and nonempty.
if r[i][4] == "#CTRI":
if r[j][4] != "#TRI":
continue
elif r[j][4] == "#CTRI":
if r[i][4] != "#TRI":
continue
else:
if r[i][4] == "" or r[j][4] == "" or r[i][4] != r[j][4]:
continue
rii = r[i][1].get_meta("pf_note")
rij = r[j][1].get_meta("pf_note")
# rii_first = ""
# if isinstance(rii, list):
#     rii_first = str(rii[0])
# elif (r[i][1] * r[i][2]).isnonneg():
#     rii_first = "#SHA"
# rij_first = ""
# if isinstance(rij, list):
#     rij_first = str(rij[0])
# elif (r[j][1] * r[j][2]).isnonneg():
#     rij_first = "#SHA"
# print(str(r[i][1]) + " " + rii_first + " " + str(r[j][1]) + " " + rij_first)
# if rii_first != rij_first:
#     continue
t = (r[j][1] + r[i][1] * isn).simplified_break_hc(bnet = bnet)
sumx = r[j][0] + r[i][0] * isn
t2 = rowt(sumx).simplified_break_hc(bnet = bnet)
if (len(t2), t2.complexity()) < (len(t), t.complexity()):
t = t2
# if t.complexity() < (r[i][1] + r[j][1]).complexity():
# if t.len_iccount() <= max(r[i][1].len_iccount(), r[j][1].len_iccount()):
if len(t) <= max(len(r[i][1]), len(r[j][1])):
# print(str(rii) + " " + str(rij))
if rii is not None or rij is not None:
t.add_meta("pf_note", rii or rij)
# print(t)
r[j] = (sumx, t, r[j][2], r[i][3] and r[j][3], r[j][4])
r.pop(i)
did = True
break
if ropt is None or len(r) < len(ropt):
ropt = r
r = ropt
# print([(x * d, t * d, iseq) for x, t, d, iseq in r])
# Split IC3
# NOTE(review): this step calls alt_prog.get_vec, which raises
# AttributeError if alt_prog is None (the default) and a split is found;
# self.get_vec is presumably intended in that case - confirm.
for i in range(len(r)):
if len(r[i][1]) == 1 and r[i][1].terms[0][0].isic3():
x = r[i][1].terms[0][0].x
lenx = len(x)
x = x + x
z = r[i][1].terms[0][0].z
coeff = r[i][1].terms[0][1] * r[i][2]
for k in range(lenx):
e1 = Expr.Ic(x[k + 1], x[k + 2], z).simplified_quick()
e2 = Expr.Ic(x[k + 1], x[k + 2], z + x[k]).simplified_quick()
e1eq = False
e2eq = False
if coeff > 0:
# if self.get_vec(e2, sparse = True, cat = True).iszero():
if ((self.bnet is not None and self.bnet.check_ic(e2))
or self.get_vec(e2, sparse = True, cat = True).iszero()):
e2eq = True
else:
# if self.get_vec(e1, sparse = True, cat = True).iszero():
if ((self.bnet is not None and self.bnet.check_ic(e1))
or self.get_vec(e1, sparse = True, cat = True).iszero()):
e1eq = True
if e1eq or e2eq:
r[i] = (numpy.array(alt_prog.get_vec(e1, cat = True)), e1, coeff, e1eq, "#TRI")
r.append((numpy.array(alt_prog.get_vec(e2, cat = True)), e2, -coeff, e2eq, "#TRI"))
break
return [(x * d, t * d, iseq, curtype) for x, t, d, iseq, curtype in r]
def get_dual_region_terms(self, expr, skip_trivial = True, ri_simplify = True, nshuffle = 3, compress = True):
    """Collect the dual-certificate terms of ``expr`` into a Region.

    Returns None when no dual solution (``dual_e`` / ``dual_u``) is stored.
    Each term produced by ``get_dual_terms`` becomes an equality constraint
    (when flagged as an equality) or a ``>= 0`` constraint of the returned
    region. Terms whose type tag is "#TRI" are left out when
    ``skip_trivial`` is True. ``compress`` is accepted for interface
    compatibility but is not used here.
    """
    if self.dual_u is None or self.dual_e is None:
        return None
    terms = self.get_dual_terms(expr, rowt = self.row_toexpr(), ri_simplify = ri_simplify, nshuffle = nshuffle)
    region = Region.universe()
    for _, term, is_eq, term_type in terms:
        if skip_trivial and term_type == "#TRI":
            continue
        bucket = region.exprs_eq if is_eq else region.exprs_ge
        bucket.append(term)
    return region
def get_dual_steps(self, expr, prefer_short = True, simplify_exhaust = False, simplify_prog = True, chain = False, step_optimize = True, alt_prog = None, ri_simplify = True, nshuffle = 3, compress = True, deficit_separate = True):
    """Order the dual-certificate terms of ``expr`` into proof steps.

    Uses the stored dual solution (``dual_e`` / ``dual_u``) and returns None
    if none is stored. The terms come from ``get_dual_terms``. They are added
    one at a time. At each step the remaining term whose running partial sum
    has the lowest ``complexity()`` (plus the penalties below) is chosen
    greedily. When ``step_optimize`` is False, the first remaining term is
    taken instead.

    Parameters (as used in this method):
        expr: the expression whose nonnegativity was certified.
        prefer_short: accepted but not used here.
        simplify_exhaust: additionally call ``simplify_break_hc`` on every
            candidate partial sum.
        simplify_prog: when ``self.bnet`` is set, simplify partial sums with
            a program built from ``get_bnet_fcn_region()``.
        chain: split ``expr`` with ``split_present(">=", lhsvar = "real")``
            and start the steps from the left part.
        step_optimize: see above.
        alt_prog: optional other program whose ``get_vec`` / ``row_toexpr``
            are used instead of this one's. In that case ``bnet`` is not used
            for simplification.
        ri_simplify, nshuffle, deficit_separate: forwarded to
            ``get_dual_terms``.
        compress: within a run of consecutive steps with the same term type,
            merge steps when the difference of their partial sums simplifies
            to a single term satisfying ``isicle2()``.

    Returns:
        ``(r0, r1, riseq, sn)``:
        - ``r0``: the added terms.
        - ``r1``: the running partial sums.
        - ``riseq``: whether each added term is an equality.
        - ``sn``: 1, or -1 in chain mode when the split relation is ">=".
        In chain mode the first entries are ``None`` / ``expr0`` / False.
    """
    verbose = PsiOpts.settings.get("verbose_proof_step", False)
    if self.dual_e is None or self.dual_u is None:
        return None
    rowt = self.row_toexpr()
    ceps = PsiOpts.settings["eps"]
    # Each entry is (vector, expression, is_equality, type_tag).
    ineqs = self.get_dual_terms(expr, rowt = rowt, alt_prog = alt_prog, ri_simplify = ri_simplify, nshuffle = nshuffle, deficit_separate = deficit_separate)
    veclen = self.nxvar + 1
    get_vec = self.get_vec
    bnet = self.bnet
    if alt_prog is not None:
        # Express vectors and expressions in the alternative program's coordinates.
        veclen = alt_prog.nxvar + 1
        get_vec = alt_prog.get_vec
        rowt = alt_prog.row_toexpr()
        bnet = None
    prog = None
    if simplify_prog and self.bnet is not None:
        prog = self.get_bnet_fcn_region().init_simplified_prog(self.index.copy())
    # NOTE: fsum (sum of all term vectors) is not read later in this method.
    fsum = numpy.zeros(veclen)
    for x, y, iseq, curtype in ineqs:
        fsum += x
    fsum = SparseMat.from_dense_row(fsum)
    # csum / csumt: running sum of the chosen terms (as a vector / as an Expr).
    csum = numpy.zeros(veclen)
    csum = SparseMat.from_dense_row(csum)
    csumt = Expr.zero()
    ineqs = [(SparseMat.from_dense_row(x), y, iseq, curtype) for x, y, iseq, curtype in ineqs]
    sn = 1
    r0 = []
    r1 = []
    riseq = []
    rcurtype = []
    left_realvars = None
    if chain:
        # Start the chain from the left-hand part of expr.
        expr0, expr, eqnstr = expr.split_present(">=", lhsvar = "real")
        gv = get_vec(expr0)
        csum = numpy.array(list(gv[0]) + [gv[1]])
        csum = SparseMat.from_dense_row(csum)
        if eqnstr == ">=":
            sn = -1
        fsum = fsum * sn + csum
        left_realvars = expr0.allcomprealvar()
        r0.append(None)
        r1.append(expr0)
        riseq.append(False)
        rcurtype.append("#START")
    exsum_gv = get_vec(expr)
    exsum = numpy.array(list(exsum_gv[0]) + [exsum_gv[1]])
    exsum = SparseMat.from_dense_row(exsum)
    # Greedy ordering: repeatedly pick the term giving the simplest partial sum.
    while ineqs:
        mini = -1
        mint = None
        minc = 1000000000000000000
        for i, (x, y, iseq, curtype) in enumerate(ineqs):
            t = None
            if prog is not None:
                t = (csumt + y * sn).simplified_quick()
            else:
                # Try both the direct partial sum and its form relative to expr; keep the simpler.
                t = rowt(csum + x * sn).simplified_quick(bnet = bnet)
                t2 = (rowt(csum - exsum + x * sn) + expr).simplified_quick(bnet = bnet)
                if t2.complexity() < t.complexity():
                    t = t2
            if simplify_exhaust:
                t.simplify_break_hc()
            tc = t.complexity()
            if chain:
                # Penalize partial sums that still mention real variables.
                tc += len(t.allcomprealvar()) * 10000
            else:
                # Prefer terms that involve real variables.
                if not y.ispresent("realvar"):
                    tc += 1000000
            if tc < minc:
                minc = tc
                mini = i
                mint = t
            if not step_optimize:
                break
        if prog is not None:
            mint = mint.simplified_prog(prog = prog)
        mint.sort()
        csumt = mint
        r0.append(ineqs[mini][1])
        csum += ineqs[mini][0] * sn
        r1.append(mint)
        riseq.append(ineqs[mini][2])
        rcurtype.append(ineqs[mini][3])
        ineqs.pop(mini)
    if compress:
        i = 0
        while i < len(r0) - 1:
            # [i, maxj] is the run of steps sharing the type tag of step i.
            maxj = i
            while maxj + 1 < len(r0) and rcurtype[maxj + 1] == rcurtype[i]:
                maxj += 1
            # Try the longest merge first: collapse steps i+1..j into one step.
            for j in range(maxj, i + 1, -1):
                cdiff = ((r1[j] - r1[i]) * sn).simplified_quick()
                if len(cdiff) == 1 and cdiff.terms[0][0].isicle2():
                    r1[i + 1] = r1[j]
                    r0[i + 1] = cdiff
                    riseq[i + 1] = all(riseq[i + 1 : j + 1])
                    del r0[i + 2 : j + 1]
                    del r1[i + 2 : j + 1]
                    del riseq[i + 2 : j + 1]
                    del rcurtype[i + 2 : j + 1]
                    break
            i += 1
    for i in range(len(r1) - 1):
        r1[i + 1].make_similar(r1[i])
    return (r0, r1, riseq, sn)
def write_pf(self, x):
    """Build a proof object for ``x >= 0`` from the stored dual solution.

    With the "proof_step_dualsum" setting on, the dual terms are ordered by
    ``get_dual_steps``. They are then rendered in one of two forms:
    - a chain of (in)equalities, when "proof_step_chain" is on;
    - a "Claim:" followed by "Have:"/"Add:" steps, otherwise.
    With the setting off, the region from ``get_dual_sum_region`` is listed
    instead.

    The result is registered via ``PsiOpts.set_setting(proof_add = ...)``.
    Nothing is added when no steps are available, or when there is at most
    one step and "proof_noskip" is off.
    """
    verbose = PsiOpts.settings.get("verbose_proof_step", False)
    write_pf_repeat_claim = PsiOpts.settings.get("proof_repeat_implicant", False)
    include_note = PsiOpts.settings.get("proof_note", False)
    note_skip_trivial = PsiOpts.settings.get("proof_note_skip_trivial", False)
    pf = None
    if PsiOpts.settings["proof_step_dualsum"]:
        chain = PsiOpts.settings["proof_step_chain"]
        alt_prog = None
        if PsiOpts.settings["proof_step_bayesnet"]:
            alt_prog = self.copy_empty_nobnet()
        dualsteps = self.get_dual_steps(x, prefer_short = PsiOpts.settings["proof_step_dualsum_short"],
            simplify_exhaust = PsiOpts.settings["proof_step_dualsum_exhaust"],
            simplify_prog = PsiOpts.settings["proof_step_dualsum_prog"],
            chain = chain,
            step_optimize = PsiOpts.settings["proof_step_optimize"],
            alt_prog = alt_prog,
            ri_simplify = PsiOpts.settings["proof_step_term_simplify"],
            nshuffle = PsiOpts.settings["proof_step_term_nshuffle"],
            compress = PsiOpts.settings["proof_step_compress"],
            deficit_separate = PsiOpts.settings["proof_deficit_separate"])
        if dualsteps is None:
            return
        # cadds: added terms; csums: running partial sums; csn: sign of the chain.
        cadds, csums, ciseqs, csn = dualsteps
        if len(cadds) == 0:
            return
        if not PsiOpts.settings["proof_noskip"] and len(cadds) <= 1:
            return
        pf = None
        if chain:
            # Each chain entry is [partial_sum, relation_to_previous, justification].
            ineqchain = []
            for i, (cadd, csum, ciseq) in enumerate(zip(cadds, csums, ciseqs)):
                cclaim = []
                pf_note = None
                # The first chain entry has cadd None (the starting expression).
                if include_note and cadd is not None:
                    if not note_skip_trivial or ciseq or not cadd.isnonneg():
                        cclaim += ["(", "since", " "]
                        pf_note = cadd.get_meta("pf_note")
                        if pf_note is not None:
                            if isinstance(pf_note, list):
                                cclaim += pf_note
                            else:
                                cclaim += [pf_note]
                            cclaim += [":", " "]
                        if ciseq:
                            cclaim.append((cadd.copy() == 0).remove_meta("pf_note"))
                        else:
                            cclaim.append((cadd.copy() >= 0).remove_meta("pf_note"))
                        cclaim += [")"]
                ceqnstr = ""
                if ciseq:
                    ceqnstr = "="
                elif i:
                    ceqnstr = "<=" if csn > 0 else ">="
                ineqchain.append([csum, ceqnstr, cclaim])
            pf = ProofObj.from_region(ineqchain, c = "Steps: ")
            if PsiOpts.settings["proof_note_color"] is not None:
                pf.meta["note_color"] = PsiOpts.settings["proof_note_color"]
            if PsiOpts.settings["proof_note_newline"] is not None:
                pf.meta["note_newline"] = PsiOpts.settings["proof_note_newline"]
        else:
            # rt: conjunction of all added terms >= 0, optionally restated in the claim.
            rt = Region.universe()
            for t in cadds:
                rt &= t >= 0
            if not write_pf_repeat_claim or rt.isuniverse():
                pf = ProofObj.from_region(x >= 0, c = "Claim: ")
            else:
                pf = ProofObj.from_region(("implies", rt, x >= 0), c = "Claim: ")
            for i, (cadd, csum) in enumerate(zip(cadds, csums)):
                cclaim = []
                if i == 0:
                    cclaim += ["Have: ", cadd >= 0]
                else:
                    cclaim += ["Add: ", cadd >= 0]
                pf_note = None
                if include_note:
                    pf_note = cadd.get_meta("pf_note")
                if pf_note is not None:
                    cclaim += [" ", "("]
                    if isinstance(pf_note, list):
                        cclaim += pf_note
                    else:
                        cclaim += [pf_note]
                    cclaim += [")"]
                # chain is always False in this branch, so the step shows csum >= 0.
                cadd = None
                if chain:
                    cadd = csum
                else:
                    cadd = csum >= 0
                pf += ProofObj.from_region(cadd, c = cclaim)
    else:
        r = self.get_dual_sum_region(x)
        if r is None:
            return
        pf = ProofObj.from_region(r, c = ["Duals for ", x >= 0])
    PsiOpts.set_setting(proof_add = pf)
def write_pf_old(self, x):
    """Legacy proof writer for ``x >= 0``, based on the dual-sum region.

    Builds a ProofObj from ``get_dual_sum_region(x)`` and registers it via
    ``PsiOpts.set_setting(proof_add = ...)``. Nothing is added when no
    dual-sum region is available.

    With the "proof_step_dualsum" setting on, the proof is a "Claim:"
    followed by "Have:"/"Add:" steps from ``Region.get_sum_seq``. The claim
    is stated as an implication of the non-trivial duals when
    "proof_repeat_implicant" is on. With the setting off, the dual region is
    listed as-is.
    """
    write_pf_repeat_claim = PsiOpts.settings.get("proof_repeat_implicant", False)
    r = self.get_dual_sum_region(x)
    if r is None:
        return
    if not PsiOpts.settings["proof_step_dualsum"]:
        PsiOpts.set_setting(proof_add = ProofObj.from_region(r, c = ["Duals for ", x >= 0]))
        return
    rt = r.removed_trivial()
    if not write_pf_repeat_claim or rt.isuniverse():
        pf = ProofObj.from_region(x >= 0, c = "Claim: ")
    else:
        pf = ProofObj.from_region(("implies", rt, x >= 0), c = "Claim: ")
    cadds, csums = r.get_sum_seq(prefer_short = PsiOpts.settings["proof_step_dualsum_short"], simplify_exhaust = PsiOpts.settings["proof_step_dualsum_exhaust"], bnet = self.bnet, target = x)
    for i, (cadd, csum) in enumerate(zip(cadds, csums)):
        if i == 0:
            pf += ProofObj.from_region(csum >= 0, c = "Have:")
        else:
            pf += ProofObj.from_region(csum >= 0, c = ["Add: ", cadd >= 0])
    PsiOpts.set_setting(proof_add = pf)
def get_extreme_rays_vec(self, A = None):
    """Return deduplicated facet normals, through the origin, of a point hull.

    The point set is the origin plus the rows of ``A``. When ``A`` is None it
    is the origin plus the rows of ``Au`` and both signs of the rows of
    ``Ae``. For each facet of its convex hull whose offset is within ``eps``
    of zero, the facet normal is scaled so its largest absolute entry is 1.
    Normals that format identically (``iutil.float_tostr``) are kept once.
    Returns a list of numpy arrays.
    """
    if A is not None:
        width = A.width
        points = numpy.vstack((numpy.zeros(width), A.tonumpyarray()))
    else:
        width = self.Au.width
        eq_rows = self.Ae.tonumpyarray()
        points = numpy.vstack((numpy.zeros(width), self.Au.tonumpyarray(), eq_rows, -eq_rows))
    hull = scipy.spatial.ConvexHull(points)
    rays = []
    seen = set()
    tol = PsiOpts.settings["eps"]
    for facet in hull.equations:
        # Only facets passing through the origin (zero offset) are kept.
        if abs(facet[-1]) > tol:
            continue
        normal = facet[:-1]
        scale = max(abs(normal))
        if scale > tol:
            normal = normal / scale
        key = ",".join(iutil.float_tostr(v) for v in normal)
        if key in seen:
            continue
        seen.add(key)
        rays.append(normal)
    return rays
def get_region_elim_rays(self, aux = None, A = None, skip_simplify = False):
    """Eliminate ``aux`` by projecting extreme rays and taking their hull.

    The rays from ``get_extreme_rays_vec(A)`` are restricted to the LP
    variables whose mask does not intersect ``aux`` (the "remaining"
    variables). The returned Region on those variables contains:
    - ``expr == 0`` for every direction orthogonal to all projected rays;
    - ``expr <= 0`` / ``expr >= 0`` for the facets through the origin of the
      cone spanned by the projected rays.
    Inequalities that ``isnonpos()`` / ``isnonneg()`` report as trivially
    true are skipped.

    Parameters:
        aux: variables to eliminate (defaults to ``Comp.empty()``).
        A: optional constraint matrix forwarded to ``get_extreme_rays_vec``.
        skip_simplify: skip the ``simplify_quick`` calls on the result.
    """
    if aux is None:
        aux = Comp.empty()
    ceps = PsiOpts.settings["eps"]
    auxmask = self.index.get_mask(aux)
    vs = self.get_extreme_rays_vec(A)
    idt = self.id_toexpr()
    var_expr = [None] * self.nxvar
    var_id = [-1] * self.nxvar
    var_id_inv = [-1] * self.nxvar
    nleft = 0
    for j in range(self.nxvar):
        var_expr[j], mask = idt(j)
        if mask & auxmask == 0:
            var_id[j] = nleft
            var_id_inv[nleft] = j
            nleft += 1

    # Rows of ma: the origin followed by the distinct projected rays.
    vset = set()
    vset.add(",".join(["0"] * nleft))
    ma = numpy.zeros((1, nleft))
    for v in vs:
        nvv = 0
        vv = numpy.zeros(nleft)
        for i in range(len(v)):
            if var_id[i] >= 0:
                vv[var_id[i]] = v[i]
                nvv += 1
        if nvv > 0:
            ts = ",".join(iutil.float_tostr(x) for x in vv)
            if ts not in vset:
                vset.add(ts)
                ma = numpy.vstack((ma, vv))

    # ma.T @ ma is symmetric positive semidefinite. eigh returns real
    # eigenvalues and orthonormal eigenvectors; eig guarantees neither, e.g.
    # for repeated eigenvalues. The change of basis below (mv = ma @ ev1,
    # t = ev1 @ normal) relies on ev1 having orthonormal columns.
    eig, ev = numpy.linalg.eigh(ma.T.dot(ma))
    nullmask = numpy.abs(eig) <= ceps
    ev0 = ev[:, nullmask]   # directions orthogonal to every projected ray
    ev1 = ev[:, ~nullmask]  # orthonormal basis of the span of the rays

    def expr_fromvec(vec):
        # Expression sum_i vec[i] * var_i over the remaining variables,
        # scaled so that its largest absolute coefficient is 1.
        vv = max(abs(vec))
        if vv > ceps:
            vec = vec / vv
        cexpr = Expr.zero()
        for i in range(nleft):
            if abs(vec[i]) > ceps:
                cexpr += var_expr[var_id_inv[i]] * vec[i]
        if not skip_simplify:
            cexpr.simplify_quick()
        return cexpr

    # Coordinates of the projected rays in the ev1 basis.
    mv = ma.dot(ev1)
    r = Region.universe()
    for i in range(ev0.shape[1]):
        expr = expr_fromvec(ev0[:, i])
        r.iand_norename(expr == 0)
    if ev1.shape[1] == 0:
        return r
    if ev1.shape[1] == 1:
        # One-dimensional span: mv has a single column (index 0). The
        # projected cone is either the whole line, one of its half-lines,
        # or just the origin.
        has_neg = False
        has_pos = False
        for i in range(len(mv)):
            if mv[i, 0] > ceps:
                has_pos = True
            if mv[i, 0] < -ceps:
                has_neg = True
        if has_neg and has_pos:
            return r
        expr = expr_fromvec(ev1[:, 0])
        # Consistent with the facet case below: if every ray p satisfies
        # n . p <= 0, the constraint is (n-expression) <= 0.
        if has_pos:
            if not expr.isnonneg():
                r.iand_norename(expr >= 0)
        elif has_neg:
            if not expr.isnonpos():
                r.iand_norename(expr <= 0)
        else:
            r.iand_norename(expr == 0)
        return r
    hull = scipy.spatial.ConvexHull(mv)
    rset = set()
    for i in range(len(hull.simplices)):
        # Facets through the origin bound the cone. Their outward normals n
        # satisfy n . p <= 0 for every projected ray p.
        if abs(hull.equations[i, -1]) > ceps:
            continue
        t = ev1.dot(hull.equations[i, :-1])
        vv = max(abs(t))
        if vv > ceps:
            t = t / vv
        ts = ",".join(iutil.float_tostr(x) for x in t)
        if ts not in rset:
            rset.add(ts)
            expr = expr_fromvec(t)
            if not expr.isnonpos():
                r.iand_norename(expr <= 0)
    if not skip_simplify:
        r.simplify_quick()
    return r
def get_extreme_rays(self, A = None):
    """Return the vectors of ``get_extreme_rays_vec(A)`` as a union of point regions.

    Each vector becomes a Region fixing every variable whose coordinate
    exceeds ``eps`` in absolute value to that coordinate. The regions are
    combined with ``|`` starting from ``RegionOp.union([])``.
    """
    to_expr = self.id_toexpr()
    rays = self.get_extreme_rays_vec(A)
    result = RegionOp.union([])
    tol = PsiOpts.settings["eps"]
    for ray in rays:
        point = Region.universe()
        for idx, coord in enumerate(ray):
            if abs(coord) <= tol:
                continue
            term, _mask = to_expr(idx)
            point &= term == coord
        result |= point
    return result
def get_vec(self, x, sparse = False, cat = False):
    """Convert expression ``x`` to coefficients over the LP variables.

    Returns ``(None, None)`` when ``optobj.mapcol(self.xvarid)`` fails, i.e.
    when ``x`` cannot be mapped onto the LP columns. Otherwise ``c1`` is the
    constant part (taken from the columns in ``self.constmap``), and:

    - ``sparse=False, cat=False``: ``(dense_list, c1)``
    - ``sparse=False, cat=True``:  ``dense_list + [c1]`` (length ``nxvar + 1``)
    - ``sparse=True,  cat=False``: ``(SparseMat, c1)`` with width ``nxvar``
    - ``sparse=True,  cat=True``:  one ``SparseMat`` of width ``nxvar + 1``
      with ``c1`` stored in the last column (index ``nxvar``), matching
      the dense ``cat`` layout.
    """
    optobj = SparseMat(self.nvar)
    optobj.addrow()
    self.addExpr(optobj, x)
    optobj.simplify_last_row()
    c1 = 0.0
    if len(self.constmap):
        c1 = optobj.sumrows(self.constmap, remove = True)[0]
    if not optobj.mapcol(self.xvarid):
        return None, None
    optobj.width = self.nxvar
    if sparse:
        if cat:
            # The constant lives in one extra trailing column at index nxvar.
            # The old code used nvar + 1, which lies outside the declared
            # width nxvar + 1 and disagrees with the dense cat layout.
            optobj.width += 1
            if c1 != 0.0:
                optobj.x[0].append((self.nxvar, c1))
            return optobj
        return (optobj, c1)
    if cat:
        return optobj.row_dense(0) + [c1]
    return (optobj.row_dense(0), c1)
def call_prog(self, c):
    """Minimize ``c . x`` over the current LP constraints.

    Parameters:
        c: objective coefficients, one per LP column.

    Returns ``(optimal_value, solution)``. ``solution`` lists the first
    ``nx`` solution coordinates, where ``nx`` is ``dual_form_ncons`` in dual
    form and ``nxvar`` otherwise. Returns ``(None, None)`` in these cases:
    - the solve fails;
    - the program has no constraints;
    - the configured solver name is not recognized (previously this fell
      through and returned a bare None, which breaks tuple unpacking).

    An all-zero objective short-circuits to the origin when no affine terms
    are present. Otherwise it is replaced by the objective ``x[0]`` and the
    reported value is forced to 0.0.
    """
    verbose = PsiOpts.settings.get("verbose_lp", False)
    ceps = PsiOpts.settings["eps"]
    nx = self.dual_form_ncons if self.dual_form else self.nxvar
    zero_obj = False
    if all(abs(x) <= ceps for x in c):
        if not self.affine_present:
            return (0.0, [0.0 for i in range(nx)])
        c = [0.0] * len(c)
        c[0] = 1.0
        zero_obj = True
    if len(self.Au.x) == 0 and len(self.Ae.x) == 0:
        return (None, None)
    if self.solver == "scipy":
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            PsiOpts.stats["numlp"] += 1
            try:
                res = scipy.optimize.linprog(c, self.solver_param["Aus"], self.bu, self.solver_param["Aes"], self.be,
                    bounds = (None, None), method = "interior-point", options={'sparse': True})
            except ValueError:
                # Newer SciPy releases removed the legacy "interior-point"
                # method. Fall back to HiGHS instead of crashing.
                res = scipy.optimize.linprog(c, self.solver_param["Aus"], self.bu, self.solver_param["Aes"], self.be,
                    bounds = (None, None), method = "highs")
        if res.status == 0:
            return (0.0 if zero_obj else float(res.fun), [float(res.x[i]) for i in range(nx)])
        return (None, None)
    elif self.solver.startswith("pulp."):
        prob = self.solver_param["prob"]
        xvar = self.solver_param["xvar"]
        cexpr = None
        for i in range(len(c)):
            if abs(c[i]) > PsiOpts.settings["eps"]:
                if cexpr is None:
                    cexpr = c[i] * xvar[str(i)]
                else:
                    cexpr += c[i] * xvar[str(i)]
        if cexpr is not None:
            prob.setObjective(cexpr)
        try:
            PsiOpts.stats["numlp"] += 1
            res = prob.solve(iutil.pulp_get_solver(self.solver))
        except Exception as err:
            if verbose:
                warnings.warn(str(err), RuntimeWarning)
            res = 0
        if pulp.LpStatus[res] == "Optimal":
            return (0.0 if zero_obj else prob.objective.value(), [xvar[str(i)].value() for i in range(nx)])
        return (None, None)
    elif self.solver.startswith("pyomo."):
        coptions = PsiOpts.get_pyomo_options()
        opt = self.solver_param["opt"]
        model = self.solver_param["model"]
        def o_rule(model):
            cexpr = None
            for i in range(len(c)):
                if abs(c[i]) > PsiOpts.settings["eps"]:
                    if cexpr is None:
                        cexpr = c[i] * model.x[i + 1]
                    else:
                        cexpr += c[i] * model.x[i + 1]
            return cexpr
        # Replace any objective left over from a previous call.
        model.del_component("o")
        model.o = pyo.Objective(rule=o_rule)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                PsiOpts.stats["numlp"] += 1
                res = opt.solve(model, options = coptions, tee = PsiOpts.settings["verbose_solver"])
            except Exception as err:
                if verbose:
                    warnings.warn(str(err), RuntimeWarning)
                res = None
        if (res is not None and res.solver.status == pyo.SolverStatus.ok
            and res.solver.termination_condition == pyo.TerminationCondition.optimal):
            return (0.0 if zero_obj else model.o(), [model.x[i + 1]() for i in range(nx)])
        return (None, None)
    elif self.solver.startswith("ortools."):
        model = self.solver_param["model"]
        xvar = self.solver_param["xvar"]
        timelimit = PsiOpts.timer_left_sec()
        if timelimit is not None:
            model.set_time_limit(timelimit * 1000)
        cexpr = None
        for i in range(len(c)):
            if abs(c[i]) > PsiOpts.settings["eps"]:
                if cexpr is None:
                    cexpr = c[i] * xvar[i]
                else:
                    cexpr += c[i] * xvar[i]
        if cexpr is not None:
            model.Minimize(cexpr)
        try:
            PsiOpts.stats["numlp"] += 1
            res = model.Solve()
        except Exception as err:
            if verbose:
                warnings.warn(str(err), RuntimeWarning)
            res = 0
        if res == ortools.linear_solver.pywraplp.Solver.OPTIMAL:
            return (0.0 if zero_obj else model.Objective().Value(), [xvar[i].solution_value() for i in range(nx)])
        return (None, None)
    # Unrecognized solver: report failure in the same form as a failed solve.
    return (None, None)
def checkexpr_ge0(self, x, saved = False, optval = None):
    """Decide whether ``x >= 0`` holds under the LP constraints.

    Parameters:
        x: expression to test.
        saved: if True, and the "lp_save_enabled" setting is on, test ``x``
            only against the solutions stored in ``self.saved_var`` instead
            of solving an LP. Returns False if some stored point makes ``x``
            fall below ``zero_cutoff`` (that point is moved to the front of
            the list), and True otherwise.
        optval: optional list. When an LP is solved, the optimal value of
            ``x`` is appended to it.

    Returns True if the inequality is established, False otherwise. Before
    solving an LP, these shortcuts are tried:
    - ``pinfeas``;
    - the ``checkfcn`` test when ``fcn_mode >= 1``;
    - ``bnet.check_ic`` for HC1BN/HMIN programs;
    - an expression whose coefficient vector is zero.

    Side effects: may set ``self.optval``, ``self.val_x``, ``self.dual_u`` /
    ``self.dual_e``; may append to ``self.saved_var``; may call
    ``write_pf``.

    Raises:
        RuntimeError: if no supported solver is configured.
    """
    verbose = PsiOpts.settings.get("verbose_lp", False)
    save_enabled = PsiOpts.settings.get("lp_save_enabled", False)
    if not save_enabled:
        saved = False
    if self.pinfeas:
        return True
    if self.eps_present:
        x = x.substituted(Expr.eps(), Expr.eps() * (self.lp_eps_obj / self.lp_eps))
    zero_cutoff = self.zero_cutoff

    if saved:
        # Cheap refutation using LP solutions saved by earlier calls.
        optobj, optobjc1 = self.get_vec(x, sparse = True)
        if optobj is None:
            return False
        for i in range(len(self.saved_var)):
            a = self.saved_var[i]
            if sum(ca * a[j] for j, ca in optobj.x[0]) + optobjc1 < zero_cutoff:
                # Move the refuting point to the front so it is tried first next time.
                for j in range(i, 0, -1):
                    self.saved_var[j], self.saved_var[j - 1] = self.saved_var[j - 1], self.saved_var[j]
                return False
        return True

    verbose_lp_cons = PsiOpts.settings.get("verbose_lp_cons", False)
    if verbose_lp_cons:
        print("============ LP constraints ============")
        print(self.get_region(skip_simplify = True))
        print("============ LP objective ============")
        optobj, optobjc1 = self.get_vec(x, sparse = True)
        if optobj is None:
            print("objective contains new terms")
        else:
            optreg = self.get_region(A = optobj, skip_simplify = True)
            if len(optreg.exprs_ge) > 0:
                print(optreg.exprs_ge[0] + optobjc1)
        print("========================================")

    if self.fcn_mode >= 1 and not self.noskip:
        if x.isnonpos_hc():
            fcn_res = True
            for (a, _coeff) in x.terms:
                if not self.checkfcn(a.x[0], a.z):
                    fcn_res = False
                    break
            if fcn_res:
                if verbose:
                    print("LP True: fcn")
                return True
            if self.fcn_mode >= 2:
                if verbose:
                    print("LP False: fcn")
                return False

    if (self.lptype == LinearProgType.HC1BN or self.lptype == LinearProgType.HMIN) and not self.noskip:
        if x.isnonpos_ic2():
            if self.bnet.check_ic(x):
                if verbose:
                    print("LP True: bnet")
                return True

    res = None
    ceps = PsiOpts.settings["eps"]
    c, c1 = self.get_vec(x)
    if c is not None and all(abs(v) <= ceps for v in c):
        # x reduces to the constant c1.
        if c1 >= -ceps:
            if verbose:
                print("LP True: zero")
            if self.dual_form and self.noskip:
                self.dual_u = [0.0] * len(self.Au.x)
                self.dual_e = [0.0] * len(self.Ae.x)
                if self.dual_pf:
                    self.write_pf(x)
            return True
        c = None
    if c is None:
        # x is unmappable or a negative constant: use a fallback objective
        # with constant offset -1.0.
        c = [-ceps * 2] + [0.0] * (self.nxvar - 1)
        c1 = -1.0
    if len(self.Au.x) == 0 and len(self.Ae.x) == 0:
        if verbose:
            print("LP False: no constraints")
        return False
    if verbose:
        print("LP nrv=" + str(self.index.num_rv()) + " nreal=" + str(self.index.num_real())
            + " nvar=" + str(self.Au.width) + "/" + str(self.nvar) + " nineq=" + str(len(self.Au.x))
            + " neq=" + str(len(self.Ae.x)) + " solver=" + self.solver)

    if self.solver == "scipy":
        rec_limit = 50
        if self.Au.width > rec_limit:
            warnings.warn("The scipy solver is not recommended for problems of size > "
                + str(rec_limit) + " (current " + str(self.Au.width) + "). "
                + "Please switch to another solver. See "
                + "https://github.com/cheuktingli/psitip#solver", RuntimeWarning)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            PsiOpts.stats["numlp"] += 1
            try:
                res = scipy.optimize.linprog(c, self.solver_param["Aus"], self.bu, self.solver_param["Aes"], self.be,
                    bounds = (None, None), method = "interior-point", options={'sparse': True})
            except ValueError:
                # Newer SciPy releases removed the legacy "interior-point"
                # method. Fall back to HiGHS instead of crashing.
                res = scipy.optimize.linprog(c, self.solver_param["Aus"], self.bu, self.solver_param["Aes"], self.be,
                    bounds = (None, None), method = "highs")
        if verbose:
            print(" status=" + str(res.status) + " optval=" + str(res.fun))
        # linprog status 2: problem infeasible.
        if self.affine_present and self.lp_bounded and res.status == 2:
            return True
        if res.status == 0:
            self.optval = float(res.fun) + c1
            if optval is not None:
                optval.append(self.optval)
            if self.val_enabled:
                self.val_x = [0.0] * self.nxvar
                for i in range(self.nxvar):
                    self.val_x[i] = float(res.x[i])
            if self.optval >= zero_cutoff:
                return True
        if res.status == 0 and self.save_res:
            self.saved_var.append(array.array("d", [float(a) for a in res.x]))
            if verbose:
                print(" added : " + str(len(self.saved_var)) + ", " + str(sum(self.saved_var[-1])))
        return False

    elif self.solver.startswith("pulp."):
        if self.dual_form:
            self.dual_form_obj = c
            self.dual_form_cutoff = c1
            self.calc_solver()
            if self.dual_form_infeas:
                return False
        prob = self.solver_param["prob"]
        xvar = self.solver_param["xvar"]
        cexpr = None
        if self.dual_form:
            for j in range(self.dual_form_ncons):
                tcoeff = self.dual_form_weights[j]
                if cexpr is None:
                    cexpr = tcoeff * xvar[str(j)]
                else:
                    cexpr += tcoeff * xvar[str(j)]
        else:
            for i in range(len(c)):
                if abs(c[i]) > PsiOpts.settings["eps"]:
                    if cexpr is None:
                        cexpr = c[i] * xvar[str(i)]
                    else:
                        cexpr += c[i] * xvar[str(i)]
        if cexpr is not None:
            prob.setObjective(cexpr)
        try:
            PsiOpts.stats["numlp"] += 1
            res = prob.solve(iutil.pulp_get_solver(self.solver))
        except Exception as err:
            if verbose:
                warnings.warn(str(err), RuntimeWarning)
            res = 0
        if verbose:
            print(" status=" + pulp.LpStatus[res] + " optval=" + str(prob.objective.value()))
        if self.affine_present and self.lp_bounded and (pulp.LpStatus[res] == "Infeasible" or pulp.LpStatus[res] == "Undefined"):
            return True
        if pulp.LpStatus[res] == "Optimal":
            if self.dual_form:
                self.dual_u = [0.0] * len(self.Au.x)
                self.dual_e = [0.0] * len(self.Ae.x)
                for j in range(len(self.Au.x)):
                    self.dual_u[j] = -xvar[str(j)].value()
                for j in range(len(self.Ae.x)):
                    self.dual_e[j] = -xvar[str(j * 2 + len(self.Au.x))].value() + xvar[str(j * 2 + len(self.Au.x) + 1)].value()
                if self.dual_pf:
                    self.write_pf(x)
                return True
            self.optval = prob.objective.value() + c1
            if optval is not None:
                optval.append(self.optval)
            if self.dual_enabled:
                self.dual_u = [0.0] * len(self.Au.x)
                self.dual_e = [0.0] * len(self.Ae.x)
                # Loop name must not be `c`: the objective list c is read
                # again below when saving the solution.
                for i, (name, con) in enumerate(prob.constraints.items()):
                    if con.pi is None:
                        self.dual_u = None
                        self.dual_e = None
                        break
                    if i < len(self.Au.x):
                        self.dual_u[i] = con.pi
                    else:
                        self.dual_e[i - len(self.Au.x)] = con.pi
                if self.dual_pf:
                    self.write_pf(x)
            if self.val_enabled:
                self.val_x = [0.0] * self.nxvar
                for i in range(self.nxvar):
                    self.val_x[i] = xvar[str(i)].value()
            if self.optval >= zero_cutoff:
                return True
            if pulp.LpStatus[res] == "Optimal" and self.save_res:
                self.saved_var.append(array.array("d", [xvar[str(i)].value() for i in range(len(c))]))
                if verbose:
                    print(" added : " + str(len(self.saved_var)) + ", " + str(sum(self.saved_var[-1])))
            return False
        else:
            # NOTE(review): any other non-optimal status is reported as True,
            # unlike the scipy/pyomo branches. Confirm this is intended.
            return True

    elif self.solver.startswith("pyomo."):
        if self.dual_form:
            self.dual_form_obj = c
            self.dual_form_cutoff = c1
            self.calc_solver()
            if self.dual_form_infeas:
                return False
        coptions = PsiOpts.get_pyomo_options()
        opt = self.solver_param["opt"]
        model = self.solver_param["model"]
        def o_rule(model):
            if self.dual_form:
                return sum(model.x[j + 1] * self.dual_form_weights[j] for j in range(self.dual_form_ncons))
            else:
                cexpr = None
                for i in range(len(c)):
                    if abs(c[i]) > PsiOpts.settings["eps"]:
                        if cexpr is None:
                            cexpr = c[i] * model.x[i + 1]
                        else:
                            cexpr += c[i] * model.x[i + 1]
                return cexpr
        model.del_component("o")
        model.o = pyo.Objective(rule=o_rule)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                PsiOpts.stats["numlp"] += 1
                res = opt.solve(model, options = coptions, tee = PsiOpts.settings["verbose_solver"])
            except Exception as err:
                if verbose:
                    warnings.warn(str(err), RuntimeWarning)
                res = None
        if verbose and res is not None:
            print(" status=" + ("OK" if res.solver.status == pyo.SolverStatus.ok else "NO"))
        if res is not None and self.affine_present and self.lp_bounded and res.solver.termination_condition == pyo.TerminationCondition.infeasible:
            return True
        if (res is not None and res.solver.status == pyo.SolverStatus.ok
            and res.solver.termination_condition == pyo.TerminationCondition.optimal):
            if self.dual_form:
                self.dual_u = [0.0] * len(self.Au.x)
                self.dual_e = [0.0] * len(self.Ae.x)
                for j in range(len(self.Au.x)):
                    self.dual_u[j] = -model.x[j + 1]()
                for j in range(len(self.Ae.x)):
                    self.dual_e[j] = -model.x[j * 2 + len(self.Au.x) + 1]() + model.x[j * 2 + len(self.Au.x) + 2]()
                if self.dual_pf:
                    self.write_pf(x)
                return True
            self.optval = model.o() + c1
            if optval is not None:
                optval.append(self.optval)
            if self.dual_enabled:
                self.dual_u = [0.0] * len(self.Au.x)
                self.dual_e = [0.0] * len(self.Ae.x)
                # Loop name must not be `c` (see the pulp branch).
                for con in model.component_objects(pyo.Constraint, active=True):
                    for i, index in enumerate(con):
                        if i < len(self.Au.x):
                            self.dual_u[i] = model.dual[con[index]]
                        else:
                            self.dual_e[i - len(self.Au.x)] = model.dual[con[index]]
                if self.dual_pf:
                    self.write_pf(x)
            if self.val_enabled:
                self.val_x = [0.0] * self.nxvar
                for i in range(self.nxvar):
                    self.val_x[i] = model.x[i + 1]()
            if self.optval >= zero_cutoff:
                return True
            if self.save_res:
                self.saved_var.append(array.array("d", [model.x[i + 1]() for i in range(len(c))]))
                if verbose:
                    print(" added : " + str(len(self.saved_var)) + ", " + str(sum(self.saved_var[-1])))
        return False

    elif self.solver.startswith("ortools."):
        if self.dual_form:
            self.dual_form_obj = c
            self.dual_form_cutoff = c1
            self.calc_solver()
            if self.dual_form_infeas:
                return False
        model = self.solver_param["model"]
        xvar = self.solver_param["xvar"]
        timelimit = PsiOpts.timer_left_sec()
        if timelimit is not None:
            model.set_time_limit(timelimit * 1000)
        cexpr = None
        if self.dual_form:
            for j in range(self.dual_form_ncons):
                tcoeff = self.dual_form_weights[j]
                if cexpr is None:
                    cexpr = tcoeff * xvar[j]
                else:
                    cexpr += tcoeff * xvar[j]
        else:
            for i in range(len(c)):
                if abs(c[i]) > PsiOpts.settings["eps"]:
                    if cexpr is None:
                        cexpr = c[i] * xvar[i]
                    else:
                        cexpr += c[i] * xvar[i]
        if cexpr is not None:
            model.Minimize(cexpr)
        try:
            PsiOpts.stats["numlp"] += 1
            res = model.Solve()
        except Exception as err:
            if verbose:
                warnings.warn(str(err), RuntimeWarning)
            res = 0
        if verbose:
            print(" status=" + str(res) + " optval=" + str(model.Objective().Value()))
        if self.affine_present and self.lp_bounded and (res == ortools.linear_solver.pywraplp.Solver.INFEASIBLE):
            return True
        if res == ortools.linear_solver.pywraplp.Solver.OPTIMAL:
            if self.dual_form:
                self.dual_u = [0.0] * len(self.Au.x)
                self.dual_e = [0.0] * len(self.Ae.x)
                for j in range(len(self.Au.x)):
                    self.dual_u[j] = -xvar[j].solution_value()
                for j in range(len(self.Ae.x)):
                    self.dual_e[j] = -xvar[j * 2 + len(self.Au.x)].solution_value() + xvar[j * 2 + len(self.Au.x) + 1].solution_value()
                if self.dual_pf:
                    self.write_pf(x)
                return True
            self.optval = model.Objective().Value() + c1
            if optval is not None:
                optval.append(self.optval)
            # Dual values are not extracted for the ortools backend.
            if self.val_enabled:
                self.val_x = [0.0] * self.nxvar
                for i in range(self.nxvar):
                    self.val_x[i] = xvar[i].solution_value()
            if self.optval >= zero_cutoff:
                return True
            if res == ortools.linear_solver.pywraplp.Solver.OPTIMAL and self.save_res:
                self.saved_var.append(array.array("d", [xvar[i].solution_value() for i in range(len(c))]))
                if verbose:
                    print(" added : " + str(len(self.saved_var)) + ", " + str(sum(self.saved_var[-1])))
            return False
        else:
            return True

    else:
        if self.solver == "":
            raise RuntimeError("No solver found. Please install a solver. See: "
                + "https://github.com/cheuktingli/psitip#solver")
        else:
            raise RuntimeError("Solver " + self.solver + " not found. Please install another solver. See: "
                + "https://github.com/cheuktingli/psitip#solver")
def checkexpr_eq0(self, x, saved = False):
    """Check ``x == 0`` as ``x >= 0`` followed by ``-x >= 0``.

    The second check runs only when the first one succeeds.
    """
    nonneg = self.checkexpr_ge0(x, saved)
    if not nonneg:
        return nonneg
    return self.checkexpr_ge0(-x, saved)
def checkexpr(self, x, sg, saved = False):
    """Check the relation ``x <sg> 0``, where ``sg`` is "==", ">=" or "<=".

    Any other ``sg`` gives False.
    """
    if sg == ">=":
        return self.checkexpr_ge0(x, saved)
    if sg == "<=":
        return self.checkexpr_ge0(-x, saved)
    if sg == "==":
        return self.checkexpr_eq0(x, saved)
    return False
def evalexpr_ge0_saved(self, x):
    """Measure how much the saved LP solutions violate ``x >= 0``.

    Returns the sum, over the points in ``self.saved_var``, of
    ``min(value, 0.0)``, where ``value`` is ``x`` evaluated at the point.
    0.0 means no saved point violates the inequality. Returns -1e8 when
    ``get_vec`` cannot map ``x`` onto the LP variables.
    """
    coeffs, offset = self.get_vec(x)
    if coeffs is None:
        return -1e8
    total = 0.0
    for point in self.saved_var:
        value = sum([ci * pi for ci, pi in zip(coeffs, point)]) + offset
        total += min(value, 0.0)
    return total
def get_val(self, expr):
    """Evaluate ``expr`` at the stored LP solution ``self.val_x``.

    Returns None when no solution is stored, or when ``get_vec`` cannot map
    ``expr`` onto the LP variables (previously this case raised
    AttributeError on ``None.x``). Otherwise returns the constant part plus
    the sum of coefficient times solution value.
    """
    if self.val_x is None:
        return None
    c, c1 = self.get_vec(expr, sparse = True)
    if c is None:
        return None
    r = c1
    if len(c.x) == 0:
        return r
    for j, a in c.x[0]:
        r += self.val_x[j] * a
    return r
def __call__(self, x):
    """Evaluate an Expr, or check a Region, at the stored LP solution.

    Expr values come from ``get_val``. A Region is passed to its
    ``evalcheck`` with ``get_val`` as the evaluator. Any other argument
    gives None.
    """
    if isinstance(x, Expr):
        return self.get_val(x)
    if not isinstance(x, Region):
        return None
    return x.evalcheck(lambda term: self.get_val(term))
def __getitem__(self, x):
    """Index syntax alias: ``prog[x]`` is the same as ``prog(x)``."""
    result = self(x)
    return result
def clear_dual(self):
    """Forget any stored dual solution (``dual_u`` and ``dual_e``)."""
    self.dual_u = self.dual_e = None
def get_dual_expr(self, x):
    """Return the total dual value attached to constraint rows parallel to ``x``.

    For each stored constraint row (equalities first, then inequalities)
    with dual value ``d``, ``a2.ratio(c)`` is computed against the row ``c``
    of ``x``. Presumably this is the scalar ``t`` with ``a2 == t * c``, or
    None when the rows are not parallel (TODO confirm SparseMat.ratio). The
    result is the sum of ``d * t`` over the rows with a ratio.

    Returns None when no dual solution is stored, or when ``get_vec`` cannot
    map ``x`` onto the LP variables (previously this case raised
    AttributeError on ``None.simplify()``).
    """
    if self.dual_u is None or self.dual_e is None:
        return None
    c, c1 = self.get_vec(x, sparse = True)
    if c is None:
        return None
    c.simplify()
    r = 0.0
    for a, d in itertools.chain(zip(self.Ae.x, self.dual_e), zip(self.Au.x, self.dual_u)):
        a2 = SparseMat.from_row(a, self.Au.width)
        a2.simplify()
        t = a2.ratio(c)
        if t is not None:
            r += d * t
    return r
def get_dual(self, reg):
    """Return dual values for an Expr or for every constraint of a Region.

    An Expr is passed straight to ``get_dual_expr``. For a Region the
    ``>=`` constraints are processed first, then the ``==`` constraints.
    A single result is returned unwrapped; otherwise the results come back
    as a list.
    """
    if isinstance(reg, Expr):
        return self.get_dual_expr(reg)
    duals = [self.get_dual_expr(e) for e in reg.exprs_ge]
    duals.extend(self.get_dual_expr(e) for e in reg.exprs_eq)
    return duals[0] if len(duals) == 1 else duals
def proj_hull(prog, n, init_pt = None, toexpr = None, iscone = False, isfrac = None, max_facet = None, num_simplex = None, init_pts_outer = None, pts_outer = None, is_dual_discover = False):
    """Convex hull method for polyhedron projection.
    C. Lassez and J.-L. Lassez, Quantifier elimination for conjunctions of linear constraints via a
    convex hull algorithm, IBM Research Report, T.J. Watson Research Center, RC 16779 (1991)

    Keeps an inner approximation of the projected set as a cdd generator
    matrix ``mat`` (the points found so far). On each pass it takes candidate
    facets of that hull and calls ``prog`` along each facet normal.
    - If the optimum does not go below the facet's offset, the facet is
      recorded as a valid ("tight") inequality.
    - Otherwise the point returned by ``prog`` is added to ``mat``.
    Passes repeat until a pass adds no point, or until
    ``PsiOpts.is_timer_ended()``.

    Parameters
    ----------
    prog : callable
        ``prog(c) -> (opt, pt)`` for a direction ``c`` of length ``n``. This
        function prints "MIN" before calling it and accepts a facet
        ``x[0] + x[1:] . p >= 0`` when ``opt >= -x[0] - eps``.
        NOTE(review): this implies ``opt`` is the minimum of ``c . p`` over
        the projected set and ``pt`` a minimiser; confirm against callers.
        An ``opt`` of None means the direction is not recorded and no point
        is added.
    n : int
        Dimension of the projected space.
    init_pt : list, optional
        Initial hull point. If None, the origin is used in the cone case.
        Otherwise the point returned by ``prog([1] * n)`` is used, falling
        back to the origin if that is None.
    toexpr : callable, optional
        Maps a coefficient vector of length ``n`` to an expression. Used for
        verbose printing and for the ``isnonneg`` shortcut that skips an LP
        call.
    iscone : bool
        Treat the set as a cone. Only facets with zero offset are tried, and
        new points are L1-normalised and added as rays (leading 0).
    isfrac : bool, optional
        Use exact fractions in cdd. Defaults to the
        ``discover_hull_frac_enabled`` setting.
    max_facet : int, optional
        Full cdd facet enumeration is used only while
        ``row_size ** (col_size / 2) <= max_facet``. Otherwise the code
        switches to randomized facet enumeration (cone case, requires the
        setting and ortools) or to random subsets.
    num_simplex : int, optional
        Only feeds ``tgt_num_point``. In this function that variable is
        overwritten before use (NOTE(review): appears to have no effect).
    init_pts_outer : list, optional
        If it has at least two points, the initial hull is built by calling
        ``prog`` from their centroid toward each point.
    pts_outer : list, optional
        If given, hull points found (initial points and points returned by
        ``prog``) are appended to it in place.
    is_dual_discover : bool
        Return the negated homogeneous rows of the added points instead of
        the tight inequalities.

    Returns
    -------
    list or None
        Without ``is_dual_discover``: a list of L1-normalised inequalities
        ``[b, a_1, ..., a_n]``, each meaning ``b + a . p >= 0``. With
        ``is_dual_discover``: the list of negated rows ``[-t, -p_1, ...]``
        of the points added. Returns None if ``init_pts_outer`` is given but
        no point moves under ``prog``.
    """
    if cdd is None:
        raise RuntimeError("Convex hull method requires pycddlib. Please install it first.")
    verbose = PsiOpts.settings.get("verbose_discover", False)
    verbose_ineq = PsiOpts.settings.get("verbose_discover_ineq", False)
    verbose_outer = PsiOpts.settings.get("verbose_discover_outer", False)
    verbose_detail = PsiOpts.settings.get("verbose_discover_detail", False)
    verbose_terms = PsiOpts.settings.get("verbose_discover_terms", False)
    verbose_terms_inner = PsiOpts.settings.get("verbose_discover_terms_inner", False)
    verbose_terms_outer = PsiOpts.settings.get("verbose_discover_terms_outer", False)
    verbose_aux_reduce = PsiOpts.settings.get("verbose_aux_reduce", False)
    #max_denom = PsiOpts.settings.get("max_denom", 1000)
    facet_enumerate = PsiOpts.settings.get("discover_hull_facet_enumerate", False)
    facet_enumerate_numyield = PsiOpts.settings.get("discover_hull_facet_enumerate_numyield", 20)
    if isfrac is None:
        isfrac = PsiOpts.settings.get("discover_hull_frac_enabled", False)
    frac_denom = PsiOpts.settings.get("discover_hull_frac_denom", -1)
    if max_facet is None:
        max_facet = PsiOpts.settings.get("discover_max_facet", 10000000000)
    if num_simplex is None:
        num_simplex = PsiOpts.settings.get("discover_num_simplex", 1)
    ceps = PsiOpts.settings["eps_lp"]
    # For a cone, only facets through the origin (zero offset) are tried.
    linear_facet_only = iscone
    rnd_started = False
    mat = None
    if init_pts_outer is not None and len(init_pts_outer) >= 2:
        # Seed the hull by optimising from the centroid of the given outer
        # points toward each of them.
        cen = [0.0] * n
        for p in init_pts_outer:
            for i in range(n):
                cen[i] += p[i]
        for i in range(n):
            cen[i] /= len(init_pts_outer)
        init_pts = []
        did = False
        for p in init_pts_outer:
            v = [cenx - px for cenx, px in zip(cen, p)]
            if all(abs(vx) <= ceps for vx in v):
                continue
            val, a = prog(v)
            if not all(abs(ax - px) <= ceps for ax, px in zip(a, p)):
                did = True
            init_pts.append(a)
            if pts_outer is not None:
                pts_outer.append(a)
        if not did:
            # prog reproduced every outer point: nothing to shrink.
            if verbose_aux_reduce:
                print("discover fails to shrink")
            return None
        # mat = cdd.Matrix([[1] + p for p in init_pts], number_type=("fraction" if isfrac else "float"))
        mat = cdd_matrix_from_array([[1] + p for p in init_pts], number_type=("fraction" if isfrac else "float"), rep_type = cdd.RepType.GENERATOR)
    else:
        if init_pt is None:
            if iscone:
                init_pt = [0] * n
            else:
                _, init_pt = prog([1] * n)
                if init_pt is None:
                    init_pt = [0] * n
        if pts_outer is not None:
            pts_outer.append(init_pt)
        # mat = cdd.Matrix([[1] + init_pt], number_type=("fraction" if isfrac else "float"))
        mat = cdd_matrix_from_array([[1] + init_pt], number_type=("fraction" if isfrac else "float"), rep_type = cdd.RepType.GENERATOR)
        # mat.rep_type = cdd.RepType.GENERATOR
    ineqs_tight = []
    ineqs_tried = []
    pts_tight = []
    # Cached matrices for the random-subset mode; reset whenever a point is added.
    matP = None
    matP_avg = None
    matQ = None
    matQ_null = None
    poly = None
    did = True
    while did:
        if PsiOpts.is_timer_ended():
            break
        if verbose:
            print("NUMPOINT = " + str(cdd_row_size(mat)) + " NUMDIM = " + str(cdd_col_size(mat)), flush = True)
        did = False
        tgt_num_point = cdd_col_size(mat) + num_simplex
        # Full cdd facet enumeration only if the rough facet-count estimate
        # row_size ** (col_size / 2) stays within max_facet.
        isfull = max_facet is None or (max_facet > 0 and cdd_col_size(mat) * 0.5 * numpy.log(cdd_row_size(mat)) <= numpy.log(max_facet)) # or tgt_num_point >= cdd_row_size(mat)
        poly = None
        ineqs = []
        lset = set()
        ineqs_row_size = 0
        ineqs_gen = None
        ineqs_isgen = False
        if isfull:
            ineqs, lset = cdd_convert(mat)
            ineqs_row_size = len(ineqs)
            # poly = cdd.Polyhedron(mat)
            # # print("MAT:")
            # # print(mat)
            # ineqs = poly.get_inequalities()
            # # print("INEQ:")
            # # print(ineqs)
            # lset = ineqs.lin_set
            # ineqs_row_size = ineqs.row_size
        elif facet_enumerate and iscone and ortools is not None:
            if not rnd_started:
                if not PsiOpts.has_timer():
                    warnings.warn("Convex hull method: Max number of facets discover_max_facet = " + str(max_facet) +
                        " reached (current = " + str(round(numpy.exp(cdd_col_size(mat) * 0.5 * numpy.log(cdd_row_size(mat))))) + "). Switching to randomized facet enumeration. Program will not terminate unless the block is enclosed by \"with PsiOpts(timelimit = ???):\" or \"with PsiOpts(stop_file = ???):\".", RuntimeWarning)
                rnd_started = True
            if verbose:
                print("FACET ENUMERATE")
            matP2 = numpy.array([[float(cdd_mat_array(mat)[i][j]) for j in range(1, cdd_col_size(mat))] for i in range(cdd_row_size(mat))], ndmin = 2)
            def tgen():
                # Conic facets have zero offset; prepend it to match the cdd row format.
                for x in iutil.conic_hull_facets(matP2, numyield = facet_enumerate_numyield, mode = "two_connected"):
                    yield numpy.array([0.0] + list(x))
            ineqs_gen = tgen()
            ineqs_isgen = True
            # Randomized modes never declare convergence; they stop on the timer.
            did = True
        else:
            if not rnd_started:
                if not PsiOpts.has_timer():
                    warnings.warn("Convex hull method: Max number of facets discover_max_facet = " + str(max_facet) +
                        " reached (current = " + str(round(numpy.exp(cdd_col_size(mat) * 0.5 * numpy.log(cdd_row_size(mat))))) + "). Switching to randomized subset. Program will not terminate unless the block is enclosed by \"with PsiOpts(timelimit = ???):\" or \"with PsiOpts(stop_file = ???):\".", RuntimeWarning)
                rnd_started = True
            rnd = PsiOpts.get_random()
            if matP is None:
                # tgt_num_point = rnd.randrange(1, cdd_col_size(mat) + 1)
                matP = numpy.array([[float(cdd_mat_array(mat)[i][j]) for j in range(1, cdd_col_size(mat))] for i in range(cdd_row_size(mat))], ndmin = 2)
                # print(matP, flush = True)
                matP_avg = numpy.mean(matP, 0)
                # print(matP_avg, flush = True)
                # Centred points; their null space spans the directions
                # orthogonal to the affine hull of the points.
                matQ = numpy.array([matP[i,:] - matP_avg for i in range(cdd_row_size(mat))])
                matQ_null = scipy.linalg.null_space(matQ)
                # print(matQ_null, flush = True)
                # currank = numpy.linalg.matrix_rank(matQ)
                currank = matQ.shape[1] - matQ_null.shape[1]
                min_num_point = max(currank, 1)
                max_num_point = min(matP.shape)
            # tgt_num_point = rnd.randrange(min_num_point, max_num_point + 1)
            tgt_num_point = min_num_point
            if verbose:
                print("RANDOM SUBSET NUMPOINT = " + str(tgt_num_point), flush = True)
            cid = list(range(cdd_row_size(mat)))
            rnd.shuffle(cid)
            # Least-squares hyperplane through a random subset of centred points
            # (through the origin as well in the cone case).
            if iscone:
                matA = numpy.array([[0.0] * matQ.shape[1]] + [matQ[cid[i],:] for i in range(tgt_num_point - 1)], ndmin = 2)
            else:
                matA = numpy.array([matQ[cid[i],:] for i in range(tgt_num_point)], ndmin = 2)
            t = numpy.linalg.lstsq(matA, numpy.ones(tgt_num_point), rcond = None)[0]
            # print(t.shape, flush = True)
            ineqs.append([matP_avg.dot(t) + 1.0] + list(-t))
            # Affine-hull equalities, marked as linearity rows via lset.
            for i in range(matQ_null.shape[1]):
                t = matQ_null[:, i]
                ineqs.append([matP_avg.dot(t)] + list(-t))
                lset.add(len(ineqs) - 1)
            ineqs_row_size = len(ineqs)
            if False:
                cmat = cdd.Matrix([cdd_mat_array(mat)[cid[i]] for i in range(tgt_num_point)], number_type=("fraction" if isfrac else "float"))
                poly = cdd.Polyhedron(cmat)
                ineqs = poly.get_inequalities()
                lset = ineqs.lin_set
                ineqs_row_size = ineqs.row_size
            did = True
        if verbose:
            print("HULL FINISHED", flush = True)
        if ineqs_gen is None:
            def tgen():
                for i in range(ineqs_row_size):
                    yield ineqs[i]
            ineqs_gen = tgen()
        if verbose_terms or verbose_terms_inner:
            if is_dual_discover:
                pass
            else:
                print("INNER:", flush = True)
                if not ineqs_isgen:
                    for i in range(ineqs_row_size):
                        y = ineqs[i]
                        print(" " + str(toexpr(y[1:n+1])) + (" == " if i in lset else " >= ") + iutil.float_tostr(-y[0]), flush = True)
        for i, tx2 in enumerate(ineqs_gen):
            if PsiOpts.is_timer_ended():
                break
            if linear_facet_only and abs(tx2[0]) > ceps:
                if verbose_detail:
                    x = tx2
                    print("SKIP AFFINE " + str(toexpr(x[1:n+1])) + " + " + str(x[0]) + " >= 0", flush = True)
                continue
            # Linearity (equality) rows are tried in both orientations.
            for sgn in ([1, -1] if i in lset else [1]):
                if PsiOpts.is_timer_ended():
                    break
                #print("IROWSIZE " + str(ineqs_row_size))
                x = [a * sgn for a in tx2]
                xnorm = sum(abs(a) for a in x)
                if xnorm <= ceps:
                    continue
                # L1-normalise so that directions can be compared with ineqs_tried.
                x = [a / xnorm for a in x]
                if not isfull:
                    # Facets from the randomized modes may cut off existing
                    # hull points; skip those that do.
                    isbad = False
                    for i2 in range(cdd_row_size(mat)):
                        y = cdd_mat_array(mat)[i2]
                        if sum(yt * xt for (yt, xt) in zip(y, x)) < -ceps:
                            isbad = True
                            break
                    if isbad:
                        if verbose_detail:
                            if is_dual_discover:
                                pass
                            else:
                                print("NOT EXTREMAL " + str(toexpr(x[1:n+1])) + " + " + str(x[0]) + " >= 0", flush = True)
                        continue
                # for/else: the else branch runs only if this direction was not tried before.
                for y in ineqs_tried:
                    if sum(abs(a - b) for a, b in zip(x, y)) <= ceps:
                        break
                else:
                    if verbose_detail:
                        print("MIN " + str(toexpr(x[1:n+1])), flush = True)
                    isshortcut = False
                    opt = None
                    v = None
                    # Homogeneous facet whose expression is already known
                    # nonnegative: accept it without an LP call.
                    if not is_dual_discover and toexpr is not None and abs(x[0]) <= ceps:
                        if toexpr(x[1:n+1]).simplified_quick().isnonneg():
                            opt = 0.0
                            isshortcut = True
                    if not isshortcut:
                        # print("PROG " + str(x))
                        opt, v = prog(x[1:n+1])
                        # print(" VS " + str(opt) + " " + str(-x[0]))
                        vo = v
                    ineqs_tried.append(list(x))
                    if verbose_detail:
                        print("MIN FINISHED", flush = True)
                    # The optimum does not cross the facet: it holds for the projected set.
                    if opt is None or opt >= -x[0] - ceps:
                        ctoadd = opt is not None
                        # ctoadd = True
                        if ctoadd:
                            ineqs_tight.append(list(x))
                        if verbose or verbose_terms_outer or verbose_ineq:
                            if verbose and opt is None:
                                print("NONE", flush = True)
                            # print(x)
                            if verbose_terms_outer or verbose_outer or verbose_ineq or abs(x[0]) <= 100:
                                if ctoadd:
                                    if is_dual_discover:
                                        if verbose_detail:
                                            print("DPT " + str(x), flush = True)
                                    else:
                                        if verbose_ineq:
                                            print(str(toexpr(x[1:n+1])) + " >= " + iutil.float_tostr(-x[0]), flush = True)
                                        if verbose:
                                            print("ADD " + str(toexpr(x[1:n+1])) + " >= " + iutil.float_tostr(-x[0]) + (" SHORTCUT" if isshortcut else ""), flush = True)
                                        if verbose_terms or verbose_terms_outer:
                                            print("OUTER:", flush = True)
                                            print(alland([toexpr(y[1:n+1]) + y[0] >= 0 for y in ineqs_tight]), flush = True)
                                            print()
                                            # for y in ineqs_tight:
                                            #     print(" " + str(toexpr(y[1:n+1])) + " >= " + iutil.float_tostr(-y[0]))
                        #print("TIGHT " + str(list(x)))
                        continue
                    # The LP found a point beyond the facet: add it to the hull.
                    if isfrac:
                        if frac_denom > 0:
                            v = [fractions.Fraction(a).limit_denominator(frac_denom) for a in v]
                        else:
                            v = [fractions.Fraction(a) for a in v]
                    # Homogenise in cdd generator format: leading 0 for a ray
                    # (cone case, L1-normalised), leading 1 for a point.
                    if iscone:
                        vnorm = sum(abs(a) for a in v)
                        if vnorm > ceps:
                            v = [a / vnorm for a in v]
                        v = [0] + v
                    else:
                        v = [1] + v
                    # for/else: add only if v does not duplicate an existing generator.
                    for i2 in range(cdd_row_size(mat)):
                        y = cdd_mat_array(mat)[i2]
                        if sum(abs(a - b) for a, b in zip(v, y)) <= ceps:
                            break
                    else:
                        if is_dual_discover:
                            if verbose_ineq:
                                print(str(-toexpr(v[1:n+1])) + " >= " + iutil.float_tostr(v[0]), flush = True)
                            if verbose:
                                print("DADD " + str(-toexpr(v[1:n+1])) + " >= " + iutil.float_tostr(v[0]), flush = True)
                        else:
                            if verbose_detail:
                                print("PT " + str(v), flush = True)
                        if pts_outer is not None:
                            pts_outer.append(vo)
                        cdd_extend(mat, [v])
                        if is_dual_discover:
                            pts_tight.append([-a for a in v])
                        # Invalidate the random-subset cache and run another pass.
                        matP = None
                        did = True
    # if pts_outer is not None:
    #     gens = poly.get_generators()
    #     pts_outer[:] = [[float(z) for z in y[1:]] for y in gens]
    if is_dual_discover:
        return pts_tight
    else:
        return ineqs_tight
def discover_hull(self, A, iscone = False, init_pts_outer = None, pts_outer = None):
    """Convex hull method for polyhedron projection.
    C. Lassez and J.-L. Lassez, Quantifier elimination for conjunctions of linear constraints via a
    convex hull algorithm, IBM Research Report, T.J. Watson Research Center, RC 16779 (1991)

    Projects this program's feasible set onto the ``m = len(A.x)`` linear
    forms given by the rows of ``A``, using ``LinearProg.proj_hull``. Each
    row ``A.x[i]`` is an iterable of ``(j, a)`` pairs over the
    ``self.nxvar`` variables.
    - A direction ``x`` in the projected space is lifted to the variable
      space as ``c = sum_i x[i] * A.x[i]``.
    - ``self.call_prog(c)`` is then solved.
    - The optimiser ``v`` is mapped back as ``r[i] = A.x[i] . v``.

    Parameters
    ----------
    A : sparse matrix whose ``x`` attribute lists the rows.
    iscone, init_pts_outer, pts_outer :
        Passed through to ``proj_hull``.

    Returns
    -------
    Whatever ``proj_hull`` returns, expressed in the projected coordinates.
    """
    n = self.nxvar
    m = len(A.x)

    def lift(x):
        # Combine the rows of A with weights x into an n-dim coefficient vector.
        c = [0.0] * n
        for i in range(m):
            for j, a in A.x[i]:
                c[j] += a * x[i]
        return c

    itoexpr = self.row_toexpr()

    def toexpr(x):
        return itoexpr(lift(x))

    def cprog(x):
        opt, v = self.call_prog(lift(x))
        if opt is None:
            return (None, None)
        # Project the optimiser back onto the m linear forms.
        r = [0.0] * m
        for i in range(m):
            for j, a in A.x[i]:
                r[i] += a * v[j]
        return (opt, r)

    return LinearProg.proj_hull(cprog, m, toexpr = toexpr, iscone = iscone, init_pts_outer = init_pts_outer, pts_outer = pts_outer)
def corners_value(self, ispolar = False, isfrac = None):
    """Enumerate the generators of the feasible polyhedron with pycddlib.

    The constraints are passed to cdd as an H-representation.
    - Rows ``[bu, -Au]`` encode ``Au x <= bu``.
    - Rows ``[be, -Ae]`` encode ``Ae x == be``; these are marked as
      linearity rows.
    The inequality rows come first and the equality rows are stacked below
    them.

    Parameters
    ----------
    ispolar : bool
        Swap the roles of generators and inequalities (and of their
        incidence lists).
    isfrac : bool, optional
        Use exact fractions in cdd. Defaults to the
        ``discover_hull_frac_enabled`` setting.

    Returns
    -------
    tuple or None
        ``(gsr, fir)``, or None when there are no constraints.
        - ``gsr``: the generator rows (pycddlib format) that have at least
          one incidence. When ``nxvar >= 2`` they are sorted by
          ``atan2`` of the first two coordinates. The coordinates come from
          the sum of incident inequality normals, or from the row itself
          when ``ispolar``.
        - ``fir``: the non-empty incidence lists of ``fi``, re-indexed into
          ``gsr``.

    Raises
    ------
    RuntimeError
        If pycddlib is not available.
    """
    if cdd is None:
        raise RuntimeError("Requires pycddlib. Please install it first.")
    if isfrac is None:
        isfrac = PsiOpts.settings.get("discover_hull_frac_enabled", False)
    n = self.nxvar
    mat_array = None
    lset = set()
    clen = 0
    if len(self.Au.x):
        mat_array = numpy.hstack([numpy.array([self.bu]).T, -self.Au.tonumpyarray()])
        clen += len(self.Au.x)
    if len(self.Ae.x):
        ma = numpy.hstack([numpy.array([self.be]).T, -self.Ae.tonumpyarray()])
        if mat_array is None:
            mat_array = ma
        else:
            mat_array = numpy.vstack([mat_array, ma])
        # Equality rows are flagged by their index in the stacked matrix.
        lset.update(range(clen, clen + len(self.Ae.x)))
        clen += len(self.Ae.x)
    if mat_array is None:
        return None
    # Build the cdd matrix from the full stacked array (not just the last
    # piece); lset indexes rows of this stacked array.
    mat = cdd_matrix_from_array(mat_array, lin_set=lset, number_type=("fraction" if isfrac else "float"), rep_type = cdd.RepType.INEQUALITY)
    poly = cdd.Polyhedron(mat)
    gs = poly.get_generators()
    fs = poly.get_inequalities()
    gi = poly.get_incidence()
    fi = poly.get_input_incidence()
    if ispolar:
        gs, fs = fs, gs
        gi, fi = fi, gi
    ng = gs.row_size
    angles = [0.0] * ng
    if n >= 2:
        for i in range(ng):
            if ispolar:
                avg = gs[i][1:]
            else:
                # Sum of the normals of the incident inequalities.
                avg = [0.0] * n
                for k in gi[i]:
                    avg = [fs[k][j + 1] + avg[j] for j in range(n)]
            angles[i] = math.atan2(avg[1], avg[0])
            if len(gi[i]) == 0:
                angles[i] = 1e20
    gsj = sorted([i for i in range(ng) if len(gi[i])], key = lambda k: angles[k])
    gsjinv = [0] * ng
    for i in range(len(gsj)):
        gsjinv[gsj[i]] = i
    gsr = [list(gs[gsj[i]]) for i in range(len(gsj))]
    #fir = [[gsjinv[a] for a in x] for x in fi if len(x) > 0 and len(x) < ng]
    fir = [[gsjinv[a] for a in x] for x in fi if len(x) > 0]
    return (gsr, fir)
def table(self, *args, **kwargs):
    """Plot the information diagram as a Karnaugh map.

    Delegates to ``universe().table``, passing this object after the other
    positional arguments.
    """
    u = universe()
    return u.table(*args, self, **kwargs)
def venn(self, *args, **kwargs):
    """Plot the information diagram as a Venn diagram.
    Can handle up to 5 random variables (uses Branko Grunbaum's Venn diagram for n=5).

    Delegates to ``universe().venn``, passing this object after the other
    positional arguments.
    """
    u = universe()
    return u.venn(*args, self, **kwargs)
class Level(IBaseObj):
    """The level of a region in the linear entropy hierarchy.

    Attributes
    ----------
    quant : int
        Positive for Sigma_n, negative for Pi_n, 0 for Delta_n.
    n : int
        Index of the level. ``n == 0`` forces ``quant == 0`` (Delta_0).
    """
    def __init__(self, quant = 0, n = 0):
        self.quant = quant
        self.n = n
        if self.n == 0:
            # Only one level exists at n == 0; it is represented as Delta_0.
            self.quant = 0
    @staticmethod
    def sigma(n):
        """
        Sigma_n.
        Returns
        -------
        Level
        """
        return Level(1 if n else 0, n)
    @staticmethod
    def pi(n):
        """
        Pi_n.
        Returns
        -------
        Level
        """
        return Level(-1 if n else 0, n)
    @staticmethod
    def delta(n):
        """
        Delta_n.
        Returns
        -------
        Level
        """
        return Level(0, n)
    def __eq__(self, other):
        """Levels are equal when both ``quant`` and ``n`` agree.

        Non-Level operands return NotImplemented instead of raising
        AttributeError.
        """
        if not isinstance(other, Level):
            return NotImplemented
        return self.quant == other.quant and self.n == other.n
    def __ne__(self, other):
        return not self == other
    def __hash__(self):
        # Defining __eq__ removes the default hash; provide one that is
        # consistent with equality.
        return hash((self.quant, self.n))
    def __and__(self, other):
        """Meet: the level with smaller ``n``.

        For equal ``n`` and different kinds, the result is Delta_n.
        """
        if self.n < other.n:
            return self
        if other.n < self.n:
            return other
        if self == other:
            return self
        return Level(0, self.n)
    def __or__(self, other):
        """Join: the level with larger ``n``.

        For equal ``n``:
        - Delta_n joined with X gives X.
        - Sigma_n joined with Pi_n gives Delta_{n+1}.
        """
        if self.n > other.n:
            return self
        if other.n > self.n:
            return other
        if self == other:
            return self
        if self.quant == 0:
            return other
        if other.quant == 0:
            return self
        return Level(0, self.n + 1)
    def __invert__(self):
        """Swap Sigma_n and Pi_n (Delta_n is unchanged)."""
        return Level(-self.quant, self.n)
    def __rshift__(self, other):
        """Level of an implication: ``(~self) | other``."""
        return (~self) | other
    def qimplies(self, other, q = 0):
        """Level of a quantified implication.

        - ``q > 0``: ``exists`` of the join.
        - ``q < 0``: ``forall`` of the implication.
        - ``q == 0``: the meet of both.
        """
        if q > 0:
            return (self | other).exists()
        elif q < 0:
            return (self >> other).forall()
        else:
            return (self | other).exists() & (self >> other).forall()
    def __bool__(self):
        return self.n != 0
    def __int__(self):
        return self.n
    def int_lv(self):
        """Ordering key ``2 * n + (1 if Sigma/Pi else 0)``."""
        return self.n * 2 + int(self.quant != 0)
    def __ge__(self, other):
        # Sigma_n and Pi_n share a key but are not equal, so they are incomparable.
        return self.int_lv() > other.int_lv() or self == other
    def __le__(self, other):
        return other >= self
    def __gt__(self, other):
        return self.int_lv() > other.int_lv()
    def __lt__(self, other):
        return other > self
    def exists(self, forall = False):
        """Level after existential quantification.

        Delta_n and Sigma_n (n > 0) become Sigma_n; everything else moves up
        to Sigma_{n+1}. With ``forall=True``, Pi plays the role of Sigma.
        """
        tquant = -1 if forall else 1
        if self.n and (self.quant == tquant or self.quant == 0):
            return Level(tquant, self.n)
        else:
            return Level(tquant, self.n + 1)
    def forall(self):
        """Level after universal quantification (see ``exists``)."""
        return self.exists(True)
    def __str__(self):
        r = ""
        if self.quant == 0:
            r += "Delta"
        elif self.quant > 0:
            r += "Sigma"
        else:
            r += "Pi"
        r += "_" + str(self.n)
        return r
    @latex_postprocess
    def _latex_(self):
        r = ""
        if self.quant == 0:
            r += r"\Delta"
        elif self.quant > 0:
            r += r"\Sigma"
        else:
            r += r"\Pi"
        r += "_"
        t = str(self.n)
        # Multi-digit subscripts need braces in LaTeX.
        if len(t) > 1:
            r += "{" + t + "}"
        else:
            r += t
        return r
    def __repr__(self):
        r = "Level."
        if self.quant == 0:
            r += "delta"
        elif self.quant > 0:
            r += "sigma"
        else:
            r += "pi"
        r += "(" + str(self.n) + ")"
        return r
class RegionLevel(IBaseObj):
    """The level of a region in the linear entropy hierarchy, with the same syntax as Region.

    Wraps a :class:`Level` in ``self.level``, so that level calculations use
    the same operators as regions (``|``, ``&``, ``~``, ``>>``, ``==``,
    ``exists``, ``forall``). Operands that are not RegionLevel are converted
    with their ``region_level()`` method.
    """
    def __init__(self, quant = 0, n = 0):
        if isinstance(quant, Level):
            self.level = quant
        elif isinstance(quant, Region):
            self.level = quant.level()
        elif isinstance(quant, RegionLevel):
            self.level = quant.level
        else:
            self.level = Level(quant, n)
    @staticmethod
    def _as_region_level(x):
        """Return ``x`` if it is a RegionLevel, else ``x.region_level()``."""
        if isinstance(x, RegionLevel):
            return x
        return x.region_level()
    @staticmethod
    def sigma(n):
        """
        Sigma_n.
        Returns
        -------
        RegionLevel
        """
        return RegionLevel(1 if n else 0, n)
    @staticmethod
    def pi(n):
        """
        Pi_n.
        Returns
        -------
        RegionLevel
        """
        return RegionLevel(-1 if n else 0, n)
    @staticmethod
    def delta(n):
        """
        Delta_n.
        Returns
        -------
        RegionLevel
        """
        return RegionLevel(0, n)
    def __or__(self, other):
        """Join of the two levels."""
        a = RegionLevel._as_region_level(self)
        b = RegionLevel._as_region_level(other)
        return RegionLevel(a.level | b.level)
    def __and__(self, other):
        # Returns the same as ``|`` (the join of the two levels); ``inter`` is the meet.
        return self | other
    def inter(self, other):
        """Meet of the two levels."""
        a = RegionLevel._as_region_level(self)
        b = RegionLevel._as_region_level(other)
        return RegionLevel(a.level & b.level)
    def __invert__(self):
        a = RegionLevel._as_region_level(self)
        return RegionLevel(~a.level)
    def __rshift__(self, other):
        a = RegionLevel._as_region_level(self)
        b = RegionLevel._as_region_level(other)
        return (~a) | b
    def qimplies(self, other, q = 0):
        a = RegionLevel._as_region_level(self)
        b = RegionLevel._as_region_level(other)
        return RegionLevel(a.level.qimplies(b.level, q = q))
    def __eq__(self, other):
        """Level of the region equation ``self == other``.

        Like Region's ``==``, this returns an object (a RegionLevel), not a bool.
        """
        a = RegionLevel._as_region_level(self)
        b = RegionLevel._as_region_level(other)
        return (a >> b) & (b >> a)
    def __ne__(self, other):
        return ~(self == other)
    def __bool__(self):
        return bool(self.level)
    def __int__(self):
        return int(self.level)
    def int_lv(self):
        return self.level.int_lv()
    def __ge__(self, other):
        a = RegionLevel._as_region_level(self)
        b = RegionLevel._as_region_level(other)
        return a.level >= b.level
    def __le__(self, other):
        a = RegionLevel._as_region_level(self)
        b = RegionLevel._as_region_level(other)
        return a.level <= b.level
    def __gt__(self, other):
        a = RegionLevel._as_region_level(self)
        b = RegionLevel._as_region_level(other)
        return a.level > b.level
    def __lt__(self, other):
        a = RegionLevel._as_region_level(self)
        b = RegionLevel._as_region_level(other)
        return a.level < b.level
    def exists(self, *args, forall = False):
        # The quantified variables in ``args`` do not affect the level.
        return RegionLevel(self.level.exists(forall = forall))
    def forall(self, *args):
        return self.exists(*args, forall = True)
    def __str__(self):
        return str(self.level)
    @latex_postprocess
    def _latex_(self):
        return self.level._latex_()
    def __repr__(self):
        r = "RegionLevel."
        if self.level.quant == 0:
            r += "delta"
        elif self.level.quant > 0:
            r += "sigma"
        else:
            r += "pi"
        r += "(" + str(self.level.n) + ")"
        return r
class RegionType:
    """Integer tags for the kind of a region object (see ``Region.get_type``)."""
    # NIL = 0, NORMAL = 1, UNION = 2, INTER = 3
    NIL, NORMAL, UNION, INTER = range(4)
class Region(IBaseObj):
"""A region consisting of equality and inequality constraints"""
def __init__(self, exprs_ge, exprs_eq, aux, inp, oup, exprs_gei = None, exprs_eqi = None, auxi = None, meta = None):
    """Build a region from constraint lists and variable sets.

    Parameters
    ----------
    exprs_ge, exprs_eq : list
        Expressions constrained to be ``>= 0`` and ``== 0``.
    aux : Comp
        Auxiliary (existentially quantified) random variables.
    inp, oup : Comp
        Input and output random variables.
    exprs_gei, exprs_eqi : list, optional
        The flipped ("i") side of the region used by the ``imp_*`` methods.
        Each defaults to an empty list.
    auxi : Comp, optional
        Flipped-side auxiliaries. Defaults to ``Comp.empty()``.
    meta : optional
        Metadata stored as-is.
    """
    self.exprs_ge = exprs_ge
    self.exprs_eq = exprs_eq
    self.aux = aux
    self.inp = inp
    self.oup = oup
    self.exprs_gei = [] if exprs_gei is None else exprs_gei
    self.exprs_eqi = [] if exprs_eqi is None else exprs_eqi
    self.auxi = Comp.empty() if auxi is None else auxi
    self.meta = meta
def get_type(self):
    """Return the RegionType tag of this object (NORMAL for a plain Region)."""
    kind = RegionType.NORMAL
    return kind
def isnormalcons(self):
    """True when nothing is present on the flipped side (see ``imp_present``)."""
    has_flipped = self.imp_present()
    return not has_flipped
@staticmethod
def universe():
    """Return a region with no constraints and no variables."""
    fresh = Comp.empty
    return Region([], [], fresh(), fresh(), fresh())
@staticmethod
def Ic(x, y, z = None):
    """Region stating ``I(x;y|z) == 0``.

    ``z`` defaults to the empty set and is removed from ``x`` and ``y``
    first. If either side becomes empty, the universe is returned.
    """
    if z is None:
        z = Comp.empty()
    x, y = x - z, y - z
    if x.isempty() or y.isempty():
        # The mutual information is trivially zero: no constraint is needed.
        return Region.universe()
    return Region([], [Expr.Ic(x, y, z)], Comp.empty(), Comp.empty(), Comp.empty())
@staticmethod
def empty():
    """Return an infeasible region, encoded as the single constraint ``-1 >= 0``."""
    infeasible = -Expr.one()
    return Region([infeasible], [], Comp.empty(), Comp.empty(), Comp.empty())
@staticmethod
def from_bool(b):
    """Return the universe for a truthy ``b`` and the empty region otherwise."""
    return Region.universe() if b else Region.empty()
@staticmethod
def parse(s):
    """Parse a string, e.g. I(X;Y,Z|W) + 2H(X Z) le 3

    The result of ``RegionParser.parse_default`` is passed through
    ``iutil.ensure_region``.
    """
    parsed = RegionParser.parse_default(s)
    return iutil.ensure_region(parsed)
def setuniverse(self):
    """Reset this region in place to have no constraints and no variables (``meta`` is kept)."""
    self.exprs_ge, self.exprs_eq = [], []
    self.aux, self.inp, self.oup = Comp.empty(), Comp.empty(), Comp.empty()
    self.exprs_gei, self.exprs_eqi = [], []
    self.auxi = Comp.empty()
def setempty(self):
    """Reset this region in place to the infeasible ``-1 >= 0`` (``meta`` is kept)."""
    self.exprs_ge = [-Expr.one()]
    self.exprs_eq = []
    self.aux, self.inp, self.oup = Comp.empty(), Comp.empty(), Comp.empty()
    self.exprs_gei, self.exprs_eqi = [], []
    self.auxi = Comp.empty()
def isempty(self):
    """Return True if a constant constraint makes the region infeasible.

    Regions with anything on the flipped side are never reported empty.
    Otherwise the region is empty when either holds:
    - an inequality has a constant value below ``-eps``;
    - an equality has a constant value with magnitude above ``eps``.
    ``get_const()`` presumably returns None for non-constant expressions
    (such values are skipped).
    """
    if len(self.exprs_gei) or len(self.exprs_eqi):
        return False
    tol = PsiOpts.settings["eps"]
    ge_consts = (e.get_const() for e in self.exprs_ge)
    if any(c is not None and c < -tol for c in ge_consts):
        return True
    eq_consts = (e.get_const() for e in self.exprs_eq)
    return any(c is not None and abs(c) > tol for c in eq_consts)
def isuniverse(self, sgn = True, canon = False):
    """Return whether the region has no constraints.

    With ``canon=True``, a region with any auxiliaries is never the universe.
    With ``sgn=False``, this instead returns ``isempty()`` (the negated
    sense).
    """
    if canon and not (self.aux.isempty() and self.auxi.isempty()):
        return False
    if not sgn:
        return self.isempty()
    return not (self.exprs_ge or self.exprs_eq or self.exprs_gei or self.exprs_eqi)
def iseq(self):
    """Is this pure equality.

    True when there are no auxiliaries, nothing on the flipped side, no
    inequalities, and at least one equality.
    """
    if self.aux.isempty() and self.auxi.isempty():
        return (not self.exprs_ge and len(self.exprs_eq) > 0
                and not self.exprs_gei and not self.exprs_eqi)
    return False
def isineq(self):
    """Is this pure inequality.

    True when there are no auxiliaries, nothing on the flipped side, no
    equalities, and at least one inequality.
    """
    if self.aux.isempty() and self.auxi.isempty():
        return (len(self.exprs_ge) > 0 and not self.exprs_eq
                and not self.exprs_gei and not self.exprs_eqi)
    return False
def expr(self):
    """Returns the sum of the expressions in this region (both sides, inequalities and equalities)."""
    terms = self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi
    total = Expr.zero()
    for term in terms:
        total = total + term
    return total
def copy(self):
    """Return a new Region with every expression and variable set copied (meta via ``iutil.copy``)."""
    def dup(exprs):
        return [e.copy() for e in exprs]
    return Region(dup(self.exprs_ge), dup(self.exprs_eq),
                  self.aux.copy(), self.inp.copy(), self.oup.copy(),
                  dup(self.exprs_gei), dup(self.exprs_eqi),
                  self.auxi.copy(), iutil.copy(self.meta))
def copy_noaux(self):
    """Like ``copy`` but with empty ``aux`` and ``auxi``."""
    def dup(exprs):
        return [e.copy() for e in exprs]
    return Region(dup(self.exprs_ge), dup(self.exprs_eq),
                  Comp.empty(), self.inp.copy(), self.oup.copy(),
                  dup(self.exprs_gei), dup(self.exprs_eqi),
                  Comp.empty(), iutil.copy(self.meta))
def noaux(self):
    """Alias of ``copy_noaux``."""
    result = self.copy_noaux()
    return result
def copy_(self, other):
    """Overwrite this region in place with a copy of ``other``."""
    def dup(exprs):
        return [e.copy() for e in exprs]
    self.exprs_ge = dup(other.exprs_ge)
    self.exprs_eq = dup(other.exprs_eq)
    self.aux, self.inp, self.oup = other.aux.copy(), other.inp.copy(), other.oup.copy()
    self.exprs_gei = dup(other.exprs_gei)
    self.exprs_eqi = dup(other.exprs_eqi)
    self.auxi = other.auxi.copy()
    self.meta = iutil.copy(other.meta)
def imp_intersection(self):
    """Return one region with both sides merged.

    The ordinary and flipped constraints are concatenated and ``aux + auxi``
    is used as the auxiliary set. Nothing is placed on the flipped side.
    """
    ge = [e.copy() for e in self.exprs_ge + self.exprs_gei]
    eq = [e.copy() for e in self.exprs_eq + self.exprs_eqi]
    merged_aux = self.aux.copy() + self.auxi.copy()
    return Region(ge, eq, merged_aux, self.inp.copy(), self.oup.copy())
def imp_intersection_noaux(self):
    """Like ``imp_intersection`` but with empty aux, inp and oup."""
    ge = [e.copy() for e in self.exprs_ge + self.exprs_gei]
    eq = [e.copy() for e in self.exprs_eq + self.exprs_eqi]
    return Region(ge, eq, Comp.empty(), Comp.empty(), Comp.empty())
def imp_copy(self):
    """Return a new region holding copies of only the flipped side (the ordinary side is empty)."""
    no_vars = Comp.empty
    return Region([], [], no_vars(), no_vars(), no_vars(),
                  [e.copy() for e in self.exprs_gei],
                  [e.copy() for e in self.exprs_eqi],
                  self.auxi.copy())
def imp_flipped(self):
    """Return a copy with the ordinary and flipped sides swapped (``meta`` is not copied)."""
    def dup(exprs):
        return [e.copy() for e in exprs]
    return Region(dup(self.exprs_gei), dup(self.exprs_eqi),
                  self.auxi.copy(), self.inp.copy(), self.oup.copy(),
                  dup(self.exprs_ge), dup(self.exprs_eq),
                  self.aux.copy())
def consonly(self):
    """Return a copy of only the ordinary side (nothing on the flipped side, no meta)."""
    ge = [e.copy() for e in self.exprs_ge]
    eq = [e.copy() for e in self.exprs_eq]
    return Region(ge, eq, self.aux.copy(), self.inp.copy(), self.oup.copy())
def imp_flippedonly(self):
    """Return the flipped side as an ordinary region (keeps auxi as aux; inp/oup empty)."""
    ge = [e.copy() for e in self.exprs_gei]
    eq = [e.copy() for e in self.exprs_eqi]
    return Region(ge, eq, self.auxi.copy(), Comp.empty(), Comp.empty())
def imp_flippedonly_noaux(self):
    """Like ``imp_flippedonly`` but with no auxiliaries."""
    ge = [e.copy() for e in self.exprs_gei]
    eq = [e.copy() for e in self.exprs_eqi]
    return Region(ge, eq, Comp.empty(), Comp.empty(), Comp.empty())
def imp_present(self):
    """Return whether the flipped side has any constraint or auxiliary."""
    if self.exprs_gei or self.exprs_eqi:
        return True
    return not self.auxi.isempty()
def imp_flip(self):
    """Swap the ordinary and flipped sides in place; returns ``self``."""
    ge, gei = self.exprs_ge, self.exprs_gei
    self.exprs_ge, self.exprs_gei = gei, ge
    eq, eqi = self.exprs_eq, self.exprs_eqi
    self.exprs_eq, self.exprs_eqi = eqi, eq
    aux, auxi = self.aux, self.auxi
    self.aux, self.auxi = auxi, aux
    return self
def imp_only_copy_to(self, other):
    """Make ``other`` hold only a copy of this region's flipped side.

    The ordinary side and inp/oup of ``other`` are cleared.
    """
    other.exprs_ge, other.exprs_eq = [], []
    other.aux, other.inp, other.oup = Comp.empty(), Comp.empty(), Comp.empty()
    other.exprs_gei = [e.copy() for e in self.exprs_gei]
    other.exprs_eqi = [e.copy() for e in self.exprs_eqi]
    other.auxi = self.auxi.copy()
def __len__(self):
    """Number of ordinary constraints (inequalities plus equalities)."""
    return sum(map(len, (self.exprs_ge, self.exprs_eq)))
def __getitem__(self, key):
    """Select ordinary constraints by index or slice.

    Inequalities are numbered first, then equalities. Returns a new region
    containing the selected constraint(s).
    """
    tagged = [(e, False) for e in self.exprs_ge]
    tagged += [(e, True) for e in self.exprs_eq]
    picked = tagged[key]
    if not isinstance(picked, list):
        picked = [picked]
    c = Region.universe()
    for e, is_eq in picked:
        (c.exprs_eq if is_eq else c.exprs_ge).append(e)
    return c
def add_meta(self, key, value, children = True):
    """Attach metadata.

    With ``children=True`` (the default), the metadata is set on every
    expression and ``self`` is returned. Otherwise the region's own
    metadata is set via ``IBaseObj.add_meta``.
    """
    if children:
        for e in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi:
            e.add_meta(key, value)
        return self
    return IBaseObj.add_meta(self, key, value)
def get_meta(self, key):
    """Return metadata ``key``.

    The region's own value is used first, then the first non-None value
    found on its expressions; None if none is found.
    """
    own = IBaseObj.get_meta(self, key)
    if own is not None:
        return own
    children = self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi
    found = (e.get_meta(key) for e in children)
    return next((t for t in found if t is not None), None)
def remove_meta(self, key):
    """Remove metadata ``key`` from the region and all its expressions; returns ``self``."""
    IBaseObj.remove_meta(self, key)
    children = self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi
    for e in children:
        e.remove_meta(key)
    return self
def sum_entrywise(self, other):
    """Add corresponding constraints of two regions.

    Constraint lists are paired position by position (truncated to the
    shorter one). Variable sets are combined with ``interleaved``.
    """
    def pair_sum(xs, ys):
        return [a + b for a, b in zip(xs, ys)]
    return Region(pair_sum(self.exprs_ge, other.exprs_ge),
                  pair_sum(self.exprs_eq, other.exprs_eq),
                  self.aux.interleaved(other.aux),
                  self.inp.interleaved(other.inp),
                  self.oup.interleaved(other.oup),
                  pair_sum(self.exprs_gei, other.exprs_gei),
                  pair_sum(self.exprs_eqi, other.exprs_eqi),
                  self.auxi.interleaved(other.auxi))
def ispresent(self, x):
    """Return whether any variable in x appears here (in constraints or variable sets, either side)."""
    if any(z.ispresent(x) for z in self.exprs_ge + self.exprs_eq):
        return True
    if any(c.ispresent(x) for c in (self.aux, self.inp, self.oup)):
        return True
    if any(z.ispresent(x) for z in self.exprs_gei + self.exprs_eqi):
        return True
    return bool(self.auxi.ispresent(x))
def __contains__(self, other):
    """Membership test whose meaning depends on ``other``.

    - ``"="``/``"=="``: any equality constraint (either side).
    - ``">="``/``"<="``: any inequality constraint (either side).
    - ``">"``/``"<"``: whether ``Expr.eps()`` is contained.
    - ``"exists"`` / ``"forall"``: whether aux / auxi is non-empty.
    - a Region (after ``rv``): delegates to ``contains_region``.
    - anything else (after ``rv``): whether it is in any expression or
      variable set.
    """
    x = other
    if isinstance(x, str):
        if x in ("=", "=="):
            return bool(self.exprs_eq) or bool(self.exprs_eqi)
        if x in (">=", "<="):
            return bool(self.exprs_ge) or bool(self.exprs_gei)
        if x in (">", "<"):
            return Expr.eps() in self
        if x == "exists":
            return not self.aux.isempty()
        if x == "forall":
            return not self.auxi.isempty()
    x = rv(x)
    if isinstance(x, Region):
        return self.contains_region(other)
    for group in (self.exprs_ge, self.exprs_eq):
        if any(x in z for z in group):
            return True
    if x in self.aux or x in self.inp or x in self.oup:
        return True
    for group in (self.exprs_gei, self.exprs_eqi):
        if any(x in z for z in group):
            return True
    return x in self.auxi
def affine_present(self):
    """Return whether there are any affine constraints.

    Checks for the variables of ``Expr.one()``, ``Expr.eps()`` or
    ``Expr.inf()``.
    """
    special = Expr.one() + Expr.eps() + Expr.inf()
    return self.ispresent(special.allcomp())
def imp_ispresent(self, x):
    """Return whether any variable in x appears on the flipped side (exprs_gei, exprs_eqi, auxi)."""
    flipped = self.exprs_gei + self.exprs_eqi
    if any(z.ispresent(x) for z in flipped):
        return True
    return bool(self.auxi.ispresent(x))
def contains_region(self, other):
    """Syntactic check that every constraint of ``other`` appears here up to scaling.

    An inequality of ``other`` matches either
    - an inequality here (either side) with a positive ``get_ratio``, or
    - an equality here (either side) with a nonzero ``get_ratio``.
    An equality of ``other`` matches only equalities with a nonzero ratio.
    Always False for a RegionOp. NOTE(review): ``get_ratio`` presumably
    returns the scale factor, or None when the expressions are not
    proportional.
    """
    if isinstance(other, RegionOp):
        return False
    ge_pool = self.exprs_ge + self.exprs_gei
    eq_pool = self.exprs_eq + self.exprs_eqi
    wanted = [(y, True) for y in other.exprs_ge] + [(y, False) for y in other.exprs_eq]
    for x, isge in wanted:
        if isge and any(t is not None and t > 0 for t in (z.get_ratio(x) for z in ge_pool)):
            continue
        if any(t is not None and t != 0 for t in (z.get_ratio(x) for z in eq_pool)):
            continue
        return False
    return True
def rename_var(self, name0, name1):
    """Rename variable ``name0`` to ``name1`` everywhere, in place."""
    for e in self.exprs_ge + self.exprs_eq:
        e.rename_var(name0, name1)
    for c in (self.aux, self.inp, self.oup):
        c.rename_var(name0, name1)
    for e in self.exprs_gei + self.exprs_eqi:
        e.rename_var(name0, name1)
    self.auxi.rename_var(name0, name1)
def rename_map(self, namemap):
    """Rename according to name map, in place; returns ``self``."""
    for e in self.exprs_ge + self.exprs_eq:
        e.rename_map(namemap)
    for c in (self.aux, self.inp, self.oup):
        c.rename_map(namemap)
    for e in self.exprs_gei + self.exprs_eqi:
        e.rename_map(namemap)
    self.auxi.rename_map(namemap)
    return self
@fcn_substitute
def substitute(self, v0, v1):
    """Substitute variable v0 by v1 (v1 can be compound), in place.

    Variable sets are only updated when ``v0`` is not an Expr. ``meta`` is
    updated when ``iutil.check_meta_subs_criteria`` allows it. Returns
    ``self``.
    """
    for e in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi:
        e.substitute(v0, v1)
    if not isinstance(v0, Expr):
        for c in (self.aux, self.inp, self.oup, self.auxi):
            c.substitute(v0, v1)
    if iutil.check_meta_subs_criteria(v0, v1, self):
        iutil.substitute(self.meta, v0, v1)
    return self
@fcn_substitute
def substitute_whole(self, v0, v1):
    """Substitute variable v0 by v1 (v1 can be compound), in place.

    Uses ``substitute_whole`` on each component. Variable sets are only
    updated when ``v0`` is not an Expr. Returns ``self``.
    """
    for e in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi:
        e.substitute_whole(v0, v1)
    if not isinstance(v0, Expr):
        for c in (self.aux, self.inp, self.oup, self.auxi):
            c.substitute_whole(v0, v1)
    if iutil.check_meta_subs_criteria(v0, v1, self):
        iutil.substitute_whole(self.meta, v0, v1)
    return self
@fcn_substitute
def substitute_aux(self, v0, v1):
    """Substitute variable v0 by v1 (v1 can be compound), and remove auxiliary v0, in place.

    When ``v0`` is not an Expr, it is removed from aux, inp, oup and auxi
    (using ``-=``). Returns ``self``.
    """
    for e in self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi:
        e.substitute(v0, v1)
    if not isinstance(v0, Expr):
        self.aux -= v0
        self.inp -= v0
        self.oup -= v0
        self.auxi -= v0
    if iutil.check_meta_subs_criteria(v0, v1, self):
        iutil.substitute(self.meta, v0, v1)
    return self
def substituted_aux(self, *args, **kwargs):
    """Substitute variable v0 by v1 (v1 can be compound), and remove auxiliary v0, return result.

    Works on a copy; ``self`` is left unchanged.
    """
    result = self.copy()
    result.substitute_aux(*args, **kwargs)
    return result
def remove_present(self, v):
    """Drop every constraint that mentions ``v``, in place.

    If ``v`` is a Comp, it is also removed from the variable sets.
    """
    def without(exprs):
        return [e for e in exprs if not e.ispresent(v)]
    self.exprs_ge = without(self.exprs_ge)
    self.exprs_eq = without(self.exprs_eq)
    self.exprs_gei = without(self.exprs_gei)
    self.exprs_eqi = without(self.exprs_eqi)
    if isinstance(v, Comp):
        self.aux -= v
        self.inp -= v
        self.oup -= v
        self.auxi -= v
def remove_notpresent(self, v):
    """Keep only constraints that mention ``v``, in place.

    If ``v`` is a Comp, the variable sets are also intersected with ``v``.
    """
    def only_with(exprs):
        return [e for e in exprs if e.ispresent(v)]
    self.exprs_ge = only_with(self.exprs_ge)
    self.exprs_eq = only_with(self.exprs_eq)
    self.exprs_gei = only_with(self.exprs_gei)
    self.exprs_eqi = only_with(self.exprs_eqi)
    if isinstance(v, Comp):
        self.aux = self.aux.inter(v)
        self.inp = self.inp.inter(v)
        self.oup = self.oup.inter(v)
        self.auxi = self.auxi.inter(v)
def remove_notcontained(self, v):
    """Drop, in place, every constraint that involves a variable outside v."""
    outside = self.allcomp() - v
    if outside.isempty():
        return
    self.remove_present(outside)
def remove_relax(self, v, sn = 1):
    """Remove v from the region by relaxing the constraints that mention it, in place.

    For sn >= 0:
      - if v appears in an implication equality (exprs_eqi), or some
        implication inequality (exprs_gei) rejects try_remove(v, -1), the
        region becomes the universe and False is returned;
      - consequence inequalities (exprs_ge) are kept only when
        try_remove(v, 1) succeeds, and equalities mentioning v are dropped;
      - if v is a Comp it is also removed from aux, inp, oup and auxi.
    For sn < 0 the same procedure runs on the implication-flipped region; on
    failure the region is set to empty.

    Returns:
        True on success, False otherwise.
    """
    if sn < 0:
        self.imp_flip()
        succeeded = self.remove_relax(v)
        if not succeeded:
            self.setempty()
            return False
        self.imp_flip()
        return True

    if any(expr.ispresent(v) for expr in self.exprs_eqi):
        self.setuniverse()
        return False
    if not all(expr.try_remove(v, -1) for expr in self.exprs_gei):
        self.setuniverse()
        return False

    self.exprs_ge = [expr for expr in self.exprs_ge if expr.try_remove(v, 1)]
    self.exprs_eq = [expr for expr in self.exprs_eq if not expr.ispresent(v)]
    if isinstance(v, Comp):
        self.aux -= v
        self.inp -= v
        self.oup -= v
        self.auxi -= v
    return True
# def removed_constraints(self, v):
# """Return the region after the contraints in v are removed.
# """
# cs = self.copy()
# if not isinstance(v, Comp):
# cs.remove_relax(v)
# return cs
# bnet = cs.get_bayesnet(roots = v)
# bnet += v
# cs.remove_relax(v)
# return (cs & bnet.get_region()).simplified_quick()
def condition(self, b):
    """Condition every constraint on the random variable b, in place.

    Returns:
        self.
    """
    for constraints in (self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi):
        for expr in constraints:
            expr.condition(b)
    return self
def conditioned(self, b):
    """Return a copy of this region conditioned on the random variable b."""
    result = self.copy()
    result.condition(b)
    return result
def symm_sort(self, terms):
    """Sort the random variables in terms assuming symmetry among those terms.

    Delegates to symm_sort(terms) on every constraint (consequence and
    implication, inequalities and equalities).
    """
    for constraints in (self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi):
        for expr in constraints:
            expr.symm_sort(terms)
def find(self, *args):
    """Forward find(*args) to the Comp of all variables in this region."""
    every_var = self.allcomp()
    return every_var.find(*args)
def placeholder(*args):
    """Return a constraint `0 >= 0` that nevertheless mentions every argument.

    Each Comp argument contributes H(a) * 0 and each Expr argument
    contributes a * 0, so the variables are recorded without constraining
    anything. Arguments of other types are ignored.
    """
    total = Expr.zero()
    for item in args:
        if isinstance(item, Comp):
            total += Expr.H(item) * 0
        elif isinstance(item, Expr):
            total += item * 0
    return total >= 0
def empty_placeholder(*args):
    """Return the empty region, with the placeholder constraint for args attached.

    The attachment is done with iand_norename.
    """
    region = Region.empty()
    region.iand_norename(Region.placeholder(*args))
    return region
def record_to(self, index):
    """Record every variable of this region into index.

    Recording order is:
      1. the consequence constraints;
      2. aux, inp and oup;
      3. the implication constraints;
      4. auxi.
    The order is kept deliberately, since the index may depend on
    insertion order.
    """
    for expr in self.exprs_ge + self.exprs_eq:
        expr.record_to(index)
    for comp in (self.aux, self.inp, self.oup):
        index.record(comp)
    for expr in self.exprs_gei + self.exprs_eqi:
        expr.record_to(index)
    index.record(self.auxi)
def name_avoid(self, name0, regs = None):
    """Return a variant of name0 that clashes with no variable of this region or of regs.

    Args:
        name0: the desired name.
        regs: optional iterable of additional regions whose names must be avoided.
    """
    taken = IVarIndex()
    self.record_to(taken)
    if regs is not None:
        for other in regs:
            other.record_to(taken)
    return taken.name_avoid(name0)
def rename_avoid(self, reg, name0):
    """Rename variable name0 in this region so that it does not collide with reg.

    The rename character (PsiOpts.settings["rename_char"]) is appended
    repeatedly until the candidate satisfies both conditions:
      - it is unused in reg;
      - unless it is name0 itself, it is unused in this region.
    """
    taken_by_reg = IVarIndex()
    reg.record_to(taken_by_reg)
    taken_by_self = IVarIndex()
    self.record_to(taken_by_self)

    def _clashes(candidate):
        if taken_by_reg.get_index_name(candidate) >= 0:
            return True
        return candidate != name0 and taken_by_self.get_index_name(candidate) >= 0

    candidate = name0
    while _clashes(candidate):
        candidate += PsiOpts.settings["rename_char"]
    if candidate != name0:
        self.rename_var(name0, candidate)
def aux_addprefix(self, pref = "@@"):
    """Prepend pref to the name of every existentially quantified variable (aux), in place."""
    count = len(self.aux.varlist)
    for pos in range(count):
        old_name = self.aux.varlist[pos].name
        self.rename_var(old_name, pref + old_name)
def aux_present(self):
    """Return True if the region has any existential (aux) or universal (auxi) auxiliary."""
    return not (self.getaux().isempty() and self.getauxi().isempty())
def aux_clear(self):
    """Forget all quantified auxiliaries (both aux and auxi), in place."""
    self.aux, self.auxi = Comp.empty(), Comp.empty()
def getaux(self):
    """Return a copy of the existentially quantified auxiliaries."""
    existential = self.aux
    return existential.copy()
def getauxi(self):
    """Return a copy of the universally quantified auxiliaries."""
    universal = self.auxi
    return universal.copy()
def getauxall(self):
    """Return aux + auxi: all quantified auxiliaries, existential first."""
    combined = self.aux + self.auxi
    return combined
def getauxs(self):
    """Return the non-empty auxiliary groups as a list of pairs.

    Each pair is (copied Comp, is_existential). The aux group comes first
    and is tagged True; the auxi group is tagged False.
    """
    groups = []
    for comp, existential in ((self.aux, True), (self.auxi, False)):
        if not comp.isempty():
            groups.append((comp.copy(), existential))
    return groups
def aux_avoid(self, reg, samesuffix = True):
    """Rename auxiliaries of self and reg so that neither collides with the other's variables.

    With samesuffix, each side is renamed as a group via aux_avoid_from.
    Otherwise each auxiliary is renamed individually with rename_avoid, in
    this order: reg's auxi, reg's aux, self's auxi, self's aux.
    """
    if samesuffix:
        self.aux_avoid_from(reg.allcomprv_noaux(), samesuffix = True)
        reg.aux_avoid_from(self.allcomprv(), samesuffix = True)
        return
    plan = (
        (reg, self, "getauxi"),
        (reg, self, "getaux"),
        (self, reg, "getauxi"),
        (self, reg, "getaux"),
    )
    for target, other, getter in plan:
        # The getter is called only when its turn comes, so earlier renames are visible.
        for var in getattr(target, getter)().varlist:
            target.rename_avoid(other, var.name)
def aux_avoid_from(self, reg, samesuffix = True):
    """Rename this region's auxiliaries so they avoid every variable name in reg, in place.

    Does nothing when PsiOpts.settings["avoid_enabled"] is false.

    With samesuffix:
      - reg (a Region, or an object with record_to such as a Comp) is
        combined with this region's non-auxiliary variables.
      - If no auxiliary clashes with that set, nothing is renamed.
      - Otherwise all auxiliaries (aux + auxi) are renamed together with a
        shared numeric suffix k, built by iutil.set_suffix_num. The replace
        modes "set", "add" and "suffix" are tried in turn.
      - The first k producing pairwise-distinct names that are all absent
        from reg is applied with rename_map.

    Without samesuffix, each aux and then each auxi variable is renamed
    individually via rename_avoid.
    """
    if not PsiOpts.settings["avoid_enabled"]:
        return
    if samesuffix:
        if isinstance(reg, Region):
            reg = reg.allcomprv()
        # Also avoid this region's own non-auxiliary names.
        reg = reg + self.allcomprv_noaux()
        auxcomp = self.getauxall()
        regindex = IVarIndex()
        reg.record_to(regindex)
        if not regindex.ispresent(auxcomp):
            return
        rename_char = PsiOpts.settings["rename_char"]
        for rep in ["set", "add", "suffix"]:
            # NOTE(review): rep is always a non-empty string, so this is always
            # range(1, 512); the 1000 alternative is unreachable.
            for k in range(1, 512 if rep else 1000):
                rdict = {}
                rset = set()
                bad = False
                for a in auxcomp:
                    t = iutil.set_suffix_num(a.get_name(), k, rename_char, replace_mode = rep)
                    if t in rset:
                        # Two auxiliaries would map to the same new name; try the next k.
                        bad = True
                        break
                    rdict[a.get_name()] = t
                    rset.add(t)
                if not bad and not any(regindex.get_index_name(a) >= 0 for a in rset):
                    self.rename_map(rdict)
                    return
        # NOTE(review): if no suffix works, this falls through silently without renaming.
    else:
        for a in self.getaux().varlist:
            self.rename_avoid(reg, a.name)
        for a in self.getauxi().varlist:
            self.rename_avoid(reg, a.name)
def iand_norename(self, other):
    """Intersect with other in place, without renaming auxiliaries to avoid clashes.

    Copies of other's four constraint lists are appended to this region's,
    and other's aux and auxi are merged in. Metadata is not combined.

    Returns:
        self.
    """
    for attr in ("exprs_ge", "exprs_eq", "exprs_gei", "exprs_eqi"):
        # Materialize the copies first so that other being self is safe.
        copies = [expr.copy() for expr in getattr(other, attr)]
        getattr(self, attr).extend(copies)
    self.aux += other.aux
    self.auxi += other.auxi
    return self
def combine_meta(self, other):
    """Merge other's metadata dict into this region's, in place.

    Behaviour by case:
      - If other has no metadata, nothing happens.
      - If this region has none, it receives iutil.copy(other.meta).
      - Otherwise, "rv_order" values are concatenated (self's first, then
        other's), keys missing from self are copied over, and any other key
        already present in self keeps self's value.

    New values are created rather than mutated in place. This way,
    metadata that was taken from other (or from an earlier region) is never
    modified through this region.
    """
    if other.meta is None:
        return
    if self.meta is None:
        self.meta = iutil.copy(other.meta)
        return
    for key, value in other.meta.items():
        if key not in self.meta:
            # Copy so later merges cannot mutate the donor's value.
            self.meta[key] = iutil.copy(value)
        elif key == "rv_order":
            # Rebind instead of `+=`: the stored list may be shared with another region.
            self.meta[key] = self.meta[key] + value
def __iand__(self, other):
    """In-place intersection `self &= other`.

    Cases:
      - A bool other acts as the universe (True: self is returned unchanged)
        or as the empty region (False: a fresh empty region is returned).
      - The result is RegionOp.inter([self]) & other when any of these hold:
        either side is a RegionOp; either side has implications; or
        auxiliaries are present while the "prefer_expand" setting is off.
      - Otherwise copies of other's constraints are appended to self. When
        any auxiliaries are present, other is copied first and the two
        sides' auxiliaries are renamed apart (aux_avoid) before merging.
        Metadata is merged with combine_meta.
    """
    if isinstance(other, bool):
        return self if other else Region.empty()
    other = iutil.ensure_region(other)
    needs_regionop = (isinstance(other, RegionOp) or self.imp_present() or other.imp_present()
        or (not PsiOpts.settings["prefer_expand"] and (self.aux_present() or other.aux_present())))
    if needs_regionop:
        return RegionOp.inter([self]) & other

    has_aux = self.aux_present() or other.aux_present()
    source = other
    if has_aux:
        source = other.copy()
        self.aux_avoid(source)
    for attr in ("exprs_ge", "exprs_eq", "exprs_gei", "exprs_eqi"):
        getattr(self, attr).extend([expr.copy() for expr in getattr(source, attr)])
    if has_aux:
        self.aux += source.aux
        self.auxi += source.auxi
    self.combine_meta(other)
    return self
def __and__(self, other):
    """Return `self & other` as a new region; self is not modified."""
    intersection = self.copy()
    intersection &= other
    return intersection
def __rand__(self, other):
    """Return `other & self` as a new region (intersection is symmetric, so self's copy absorbs other)."""
    intersection = self.copy()
    intersection &= other
    return intersection
def attach(self, other, rv_add = None):
    """Intersect with other in place and add a Markov constraint linking the two.

    Variable sets:
      - "right": the random variables new in other.
      - "center": the neighbours of right in the combined region.
      - "left": self's remaining variables.
    rv_add optionally enlarges self's variable set. When both left and
    right are non-empty, markov(left, center, right) is intersected in.

    Returns:
        The (possibly rebound) intersected region.
    """
    own_vars = self.allcomprv_noaux()
    if rv_add is not None:
        own_vars += rv_add
    their_vars = other.allcomprv_noaux()
    self &= other
    right = their_vars - own_vars
    center = self.var_neighbors(right) - right
    left = own_vars - center
    if not (left.isempty() or right.isempty()):
        self &= markov(left, center, right)
    return self
def __pow__(self, other):
    """Return the intersection of `other` copies of this region (the universe if other <= 0)."""
    if other <= 0:
        return Region.universe()
    power = self.copy()
    for _ in range(1, other):
        power &= self
    return power
def __or__(self, other):
    """Return the union `self | other` (a copy of other when self is empty)."""
    other = iutil.ensure_region(other)
    return other.copy() if self.isempty() else RegionOp.union([self]) | other
def __ror__(self, other):
    """Return the union `other | self` (a copy of other when self is empty)."""
    other = iutil.ensure_region(other)
    return other.copy() if self.isempty() else RegionOp.union([self]) | other
def __ior__(self, other):
    """Return the union for `self |= other`.

    The result is always a new object, never self mutated in place.
    """
    other = iutil.ensure_region(other)
    return other.copy() if self.isempty() else RegionOp.union([self]) | other
def implicate(self, other, skip_simplify = False):
    """Turn self into the implication `other >> self` (other implies self), in place.

    Steps:
      - Unless skip_simplify, a copy of other that itself contains
        implications is simplified first (when the "imp_simplify" setting
        is on).
      - Auxiliaries are renamed apart.
      - other's consequence constraints become implication constraints of
        self, and vice versa.
      - other's universal auxiliaries become existential here (prepended to
        aux), and other's existential ones become universal.

    Returns:
        self.
    """
    premise = iutil.ensure_region(other).copy()
    if not skip_simplify and PsiOpts.settings["imp_simplify"] and premise.imp_present():
        premise.simplify()
    self.aux_avoid(premise)
    self.exprs_ge += [expr.copy() for expr in premise.exprs_gei]
    self.exprs_eq += [expr.copy() for expr in premise.exprs_eqi]
    self.exprs_gei += [expr.copy() for expr in premise.exprs_ge]
    self.exprs_eqi += [expr.copy() for expr in premise.exprs_eq]
    self.aux = premise.auxi + self.aux
    self.auxi += premise.aux
    return self
def implicated(self, other, skip_simplify = False):
    """Return the region `other >> self` (other implies self) without modifying self.

    The general RegionOp machinery is used when a plain Region cannot
    express the result. That happens when:
      - other is a RegionOp;
      - self has universal auxiliaries;
      - other has implications;
      - other has existential auxiliaries.
    """
    other = iutil.ensure_region(other)
    if (isinstance(other, RegionOp) or not self.auxi.isempty()
            or other.imp_present() or not other.aux.isempty()):
        return RegionOp.union([self]).implicated(other, skip_simplify)
    result = self.copy()
    result.implicate(other, skip_simplify)
    return result
def __le__(self, other):
    """`self <= other`: the region stating that self implies other."""
    target = iutil.ensure_region(other)
    return target.implicated(self)
def __ge__(self, other):
    """`self >= other`: the region stating that other implies self."""
    premise = iutil.ensure_region(other)
    return self.implicated(premise)
def __rshift__(self, other):
    """`self >> other`: the region stating that self implies other."""
    target = iutil.ensure_region(other)
    return target.implicated(self)
def __rrshift__(self, other):
    """`other >> self` (reflected): the region stating that other implies self."""
    premise = iutil.ensure_region(other)
    return self.implicated(premise)
def __lshift__(self, other):
    """`self << other`: the region stating that other implies self."""
    premise = iutil.ensure_region(other)
    return self.implicated(premise)
def __rlshift__(self, other):
    """`other << self` (reflected): the region stating that self implies other."""
    target = iutil.ensure_region(other)
    return target.implicated(self)
def __eq__(self, other):
    """`self == other`: the region stating that self and other imply each other.

    Note that this returns a RegionOp, not a bool.
    """
    other = iutil.ensure_region(other)
    forward = self.implicated(other)
    backward = other.implicated(self)
    return RegionOp.inter([forward, backward])
def __ne__(self, other):
    """`self != other`: the negation of the mutual-implication region built by `==`."""
    other = iutil.ensure_region(other)
    forward = self.implicated(other)
    backward = other.implicated(self)
    return ~(RegionOp.inter([forward, backward]))
def relax_term(self, term, gap):
    """Relax every constraint involving the real term by gap, in place.

    Steps:
      - The region is quickly simplified.
      - Each equality containing term is split into the pair `e >= 0`,
        `-e >= 0`.
      - In every `>=` constraint, term is replaced by term + gap where its
        coefficient is positive, and by term - gap where it is negative.
    For gap >= 0 this loosens each inequality.
    """
    self.simplify_quick()
    for eq in self.exprs_eq:
        if eq.get_coeff(term) == 0.0:
            continue
        self.exprs_ge.append(eq.copy())
        self.exprs_ge.append(-eq)
        eq.setzero()
    for ineq in self.exprs_ge:
        coeff = ineq.get_coeff(term)
        if coeff > 0.0:
            ineq.substitute(Expr.fromterm(term), Expr.fromterm(term) + gap)
        elif coeff < 0.0:
            ineq.substitute(Expr.fromterm(term), Expr.fromterm(term) - gap)
    self.exprs_eq = [eq for eq in self.exprs_eq if not eq.iszero()]
def relax(self, w, gap):
    """Relax every real variable appearing in w by gap (see relax_term), in place; returns self."""
    for term, _coeff in w.terms:
        if term.get_type() != TermType.REAL:
            continue
        self.relax_term(term, gap)
    return self
def relaxed(self, w, gap):
    """Return a copy of this region with the real variables in w relaxed by gap."""
    result = self.copy()
    result.relax(w, gap)
    return result
def balance(self, v = None, w = None, skip_simplify = False):
    """Replace constraints by their "balanced" forms (see Expr.balanced), in place.

    Args:
        v: forwarded to Expr.balanced / Expr.isbalanced.
        w: forwarded to Expr.balanced; defaults to all random variables of the region.
        skip_simplify: if false, simplify() is called at the end.

    Effects:
      - Every `>=` constraint (consequence and implication) is overwritten
        by x.balanced(v, w, sn = 1), or by zero (later dropped) when that
        returns None.
      - Every equality that is not already balanced is replaced by up to
        two `>=` constraints: its sn = 1 balanced form and the negation of
        its sn = -1 balanced form. Consequence equalities go to exprs_ge and
        implication equalities to exprs_gei.
      - All zeroed constraints are then removed.

    Returns:
        self.
    """
    if w is None:
        w = self.allcomprv()
    # Must run before the equality loops below append new >= constraints,
    # so that those appended constraints are not balanced a second time.
    for x in self.exprs_ge + self.exprs_gei:
        t = x.balanced(v, w, sn = 1)
        if t is None:
            t = Expr.zero()
        x.copy_(t)
    for x in self.exprs_eq:
        if x.isbalanced(v):
            continue
        t = x.balanced(v, w, sn = 1)
        if t is not None:
            self.exprs_ge.append(t)
        t = x.balanced(v, w, sn = -1)
        if t is not None:
            self.exprs_ge.append(-t)
        x.setzero()
    for x in self.exprs_eqi:
        if x.isbalanced(v):
            continue
        t = x.balanced(v, w, sn = 1)
        if t is not None:
            self.exprs_gei.append(t)
        t = x.balanced(v, w, sn = -1)
        if t is not None:
            self.exprs_gei.append(-t)
        x.setzero()
    self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()]
    self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()]
    self.exprs_gei = [x for x in self.exprs_gei if not x.iszero()]
    self.exprs_eqi = [x for x in self.exprs_eqi if not x.iszero()]
    if not skip_simplify:
        self.simplify()
    return self
def balanced(self, v = None, skip_simplify = False):
    """Return a balanced copy of this region (see balance), simplified unless skip_simplify."""
    result = self.copy()
    result.balance(v, skip_simplify = True)
    return result if skip_simplify else result.simplified()
def one_flipped(self, strict = False):
    """Negate a region made of exactly one inequality `e >= 0`.

    Returns None unless the region has no equalities and exactly one
    consequence inequality.
      - Without strict, the result is `e <= 0`.
      - With strict, e's eps term is removed (giving r). The result is
        `r + eps <= 0` when the eps coefficient is >= 0, else `r <= 0`.
    """
    if self.exprs_eq or len(self.exprs_ge) != 1:
        return None
    (only,) = self.exprs_ge
    if not strict:
        return only <= 0
    eps_coeff = only.get_coeff(Term.eps())
    rest = only.substituted(Expr.eps(), Expr.zero())
    return (rest + Expr.eps() <= 0) if eps_coeff >= 0 else (rest <= 0)
def broken_present(self, w, flipped = True):
    """Convert region to intersection of individual constraints if they contain w.

    The result is a RegionOp.inter with these members:
      - one single-constraint region per equality mentioning w;
      - one per inequality mentioning w, written as ~(x <= 0) when flipped
        and as x >= 0 otherwise;
      - finally the remaining constraints, without auxiliaries.
    The whole is wrapped in exists(aux).forall(auxi) of the original.
    Regions with implications are delegated to RegionOp.
    """
    if self.imp_present():
        return RegionOp.from_region(self).broken_present(w, flipped)
    pieces = RegionOp.inter([])
    remainder = self.copy()
    for eq in remainder.exprs_eq:
        if eq.ispresent(w):
            pieces.regs.append((eq == 0, True))
            eq.setzero()
    for ineq in remainder.exprs_ge:
        if ineq.ispresent(w):
            single = ~(ineq <= 0) if flipped else ineq >= 0
            pieces.regs.append((single, True))
            ineq.setzero()
    remainder.exprs_eq = [eq for eq in remainder.exprs_eq if not eq.iszero()]
    remainder.exprs_ge = [ineq for ineq in remainder.exprs_ge if not ineq.iszero()]
    remainder.aux = Comp.empty()
    remainder.auxi = Comp.empty()
    pieces.regs.append((remainder, True))
    return pieces.exists(self.aux).forall(self.auxi)
def corners_optimum(self, w, sn):
    """Return union of regions corresponding to maximum/minimum of the real variable w.

    If an equality already involves w, a copy of the region is returned.
    Otherwise:
      - For every inequality whose coefficient on w has sign opposite to sn
        (one bounding w in the direction of optimization), a copy is made
        with that inequality turned into an equality and the auxiliaries
        cleared.
      - No such copy gives the universe; one copy gives that copy; several
        give their union. Either of the latter is wrapped in
        exists(aux).forall(auxi) of the original.
    """
    key = w.terms[0][0]
    if any(eq.get_coeff(key) != 0 for eq in self.exprs_eq):
        return self.copy()
    variants = []
    for pos, ineq in enumerate(self.exprs_ge):
        if ineq.get_coeff(key) * sn >= 0:
            continue
        tight = self.copy()
        tight.exprs_ge.pop(pos)
        tight.exprs_eq.append(ineq.copy())
        tight.aux = Comp.empty()
        tight.auxi = Comp.empty()
        variants.append(tight)
    if not variants:
        return Region.universe()
    if len(variants) == 1:
        return variants[0].exists(self.aux).forall(self.auxi)
    return RegionOp.union(variants).exists(self.aux).forall(self.auxi)
def corners_optimum_eq(self, w, sn):
    """Return the intersection of implications characterizing the optimum of the real variable w.

    If an equality already involves w, a copy of the region is returned.
    Otherwise:
      - Starting from a copy with every constraint mentioning w's variable
        removed, one region is built per inequality whose coefficient on w
        has sign opposite to sn. Each adds that inequality as an implication
        equality (exprs_eqi).
      - No such region gives the universe; one gives that region; several
        give their intersection (RegionOp.inter).
    """
    key = w.terms[0][0]
    if any(eq.get_coeff(key) != 0 for eq in self.exprs_eq):
        return self.copy()
    base = self.copy()
    base.remove_present(key.x[0])
    pieces = []
    for ineq in self.exprs_ge:
        if ineq.get_coeff(key) * sn < 0:
            piece = base.copy()
            piece.exprs_eqi.append(ineq.copy())
            pieces.append(piece)
    if not pieces:
        return Region.universe()
    if len(pieces) == 1:
        return pieces[0]
    return RegionOp.inter(pieces)
def corners(self, w):
    """Return union of regions corresponding to corner points of the real variables in w.

    Args:
        w: an Expr, or an iterable of Exprs. Only their REAL-type terms are used.

    Procedure:
      - Compute the rank of the coefficient matrix that the equalities
        place on those terms.
      - For every choice of (rankall - rank) inequalities that lifts the
        rank to rankall, build a region in which those inequalities become
        equalities and the other w-involving inequalities stay inequalities.
        Here rankall is the rank of the equalities together with all
        inequalities.
      - The union of these regions is returned, wrapped in exists(aux).

    NOTE(review): when the equalities already have full rank the method
    returns a one-element *list* [self.copy()], not a region. This is
    inconsistent with the normal return; check callers before changing it.
    """
    terms = []
    if isinstance(w, Expr):
        for (a, c) in w.terms:
            if a.get_type() == TermType.REAL:
                terms.append(a)
    else:
        for w2 in w:
            for (a, c) in w2.terms:
                if a.get_type() == TermType.REAL:
                    terms.append(a)
    n = len(terms)
    cmat = []     # coefficient rows of the equalities
    cmatall = []  # coefficient rows of equalities and all inequalities
    for x in self.exprs_eq:
        coeff = [x.get_coeff(term) for term in terms]
        cmat.append(coeff[:])
        cmatall.append(coeff[:])
    rank = numpy.linalg.matrix_rank(cmat)
    if rank >= n:
        return [self.copy()]
    cs = self.copy()
    cs.aux = Comp.empty()
    ges = []  # copies of inequalities that involve some term of w
    gec = []  # their coefficient rows
    for x in cs.exprs_ge:
        coeff = [x.get_coeff(term) for term in terms]
        cmatall.append(coeff[:])
        if len([x for x in coeff if abs(x) > PsiOpts.settings["eps"]]) > 0:
            ges.append(x.copy())
            gec.append(coeff)
            x.setzero()
    cs.exprs_ge = [x for x in cs.exprs_ge if not x.iszero()]
    rankall = numpy.linalg.matrix_rank(cmatall)
    r = []
    for comb in itertools.combinations(range(len(ges)), rankall - rank):
        mat = cmat[:]
        for i in comb:
            mat.append(gec[i])
        if numpy.linalg.matrix_rank(mat) >= rankall:
            cs2 = cs.copy()
            for i2 in range(len(ges)):
                if i2 in comb:
                    cs2.exprs_eq.append(ges[i2].copy())
                else:
                    cs2.exprs_ge.append(ges[i2].copy())
            r.append(cs2)
    return RegionOp.union(r).exists(self.aux)
def corners_value(self, w, ispolar = False, skip_discover = False, inf_value = 1e6):
    """Return the vertices and the facet list of the polytope with coordinates in the list w.

    Args:
        w: list of expressions giving the coordinates to project onto.
        ispolar: forwarded to the linear program's corners_value.
        skip_discover: internal flag. When false, w is first replaced by
            fresh real variables "#TMPVAR..." linked to w via discover().
        inf_value: magnitude given to the largest coordinate of a ray.

    Returns:
        (r, f). f is passed through from prog.corners_value. Each entry of
        r is [flag, x_1, ..., x_k]:
          - For a point, flag is +1.0 / -1.0, the sign of the generator's
            leading entry.
          - For a ray (leading entry ~ 0), flag is 0.0 and the coordinates
            are scaled so that the largest magnitude equals inf_value.
          - A ray whose projection onto w is all zero is reported as zeros
            rather than raising ZeroDivisionError.
    """
    if not skip_discover:
        t = real_array("#TMPVAR", 0, len(w))
        return self.discover([(a, b) for a, b in zip(t, w)]).corners_value(t, ispolar, True, inf_value)
    cindex = IVarIndex()
    self.record_to(cindex)
    for a in w:
        a.record_to(cindex)
    prog = self.imp_flipped().init_prog(index = cindex)
    g, f = prog.corners_value(ispolar = ispolar)
    A = SparseMat(0)
    Ab = []
    for a in w:
        c, c1 = prog.get_vec(a, sparse = True)
        A.extend(c)
        Ab.append(c1)
    r = []
    for b in g:
        # Project the generator (entries after the leading one) onto the coordinates w.
        v = [sum([b[j + 1] * c for j, c in row], 0.0) for row in A.x]
        if abs(b[0]) < 1e-11:
            # Ray: normalize so that the largest coordinate has magnitude inf_value.
            maxv = max((abs(x) for x in v), default = 0.0)
            if maxv > 0.0:
                v = [0.0] + [(x / maxv) * inf_value for x in v]
            else:
                # The ray is invisible in the w coordinates; avoid dividing by zero.
                v = [0.0] + [0.0 for _ in v]
        else:
            v = [1.0 if b[0] > 0 else -1.0] + [x / abs(b[0]) + y for x, y in zip(v, Ab)]
        r.append(v)
    return (r, f)
def sign_present(self, term):
    """Report in which directions term is constrained.

    Returns:
        [neg, pos], set as follows:
          - A positive coefficient in a consequence inequality sets pos; a
            negative one sets neg.
          - For implication inequalities the roles are swapped.
          - A nonzero coefficient in any equality sets both.
    """
    flags = [False, False]

    def _mark(coeff, swapped):
        if coeff > 0.0:
            flags[0 if swapped else 1] = True
        elif coeff < 0.0:
            flags[1 if swapped else 0] = True

    for ineq in self.exprs_ge:
        _mark(ineq.get_coeff(term), False)
    for eq in self.exprs_eq:
        if eq.get_coeff(term) != 0.0:
            flags[0] = flags[1] = True
    for ineq in self.exprs_gei:
        _mark(ineq.get_coeff(term), True)
    for eq in self.exprs_eqi:
        if eq.get_coeff(term) != 0.0:
            flags[0] = flags[1] = True
    return flags
def substitute_sign(self, v0, v1s):
    """Substitute v0 by a sign-dependent replacement, in place, and report the signs seen.

    Args:
        v0: an Expr whose first term identifies the variable to replace.
        v1s: a pair (replacement for the "negative" side, replacement for
            the "positive" side), or a falsy value to only scan without
            modifying anything.

    Effects when v1s is given:
      - Equalities mentioning v0 are first split into `>=` pairs.
      - In consequence inequalities, v0 becomes v1s[1] where its
        coefficient is positive and v1s[0] otherwise.
      - Implication inequalities use the opposite choice.

    Returns:
        [neg_seen, pos_seen], flags computed the same way as the replacement choice.
    """
    v0term = v0.terms[0][0]
    sn_present = [False] * 2
    for x in self.exprs_eq:
        if x.ispresent(v0):
            if v1s:
                self.exprs_ge.append(x.copy())
                self.exprs_ge.append(-x)
                x.setzero()
            sn_present[0] = True
            sn_present[1] = True
    for x in self.exprs_eqi:
        if x.ispresent(v0):
            if v1s:
                self.exprs_gei.append(x.copy())
                self.exprs_gei.append(-x)
                x.setzero()
            sn_present[0] = True
            sn_present[1] = True
    # The inequality loops also visit the halves appended above.
    for x in self.exprs_ge:
        if x.ispresent(v0):
            c = x.get_coeff(v0term)
            if c > 0.0:
                if v1s:
                    x.substitute(v0, v1s[1])
                sn_present[1] = True
            else:
                if v1s:
                    x.substitute(v0, v1s[0])
                sn_present[0] = True
    for x in self.exprs_gei:
        if x.ispresent(v0):
            c = x.get_coeff(v0term)
            if c > 0.0:
                if v1s:
                    x.substitute(v0, v1s[0])
                sn_present[0] = True
            else:
                if v1s:
                    x.substitute(v0, v1s[1])
                sn_present[1] = True
    if v1s:
        self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()]
        self.exprs_eqi = [x for x in self.exprs_eqi if not x.iszero()]
    return sn_present
def substitute_duplicate(self, v0, v1s):
    """Replace each constraint mentioning v0 by one copy per substitute in v1s, in place.

    The copy for v1 is the constraint with v0 substituted by v1. Copies are
    appended to the same list the original came from. The original is then
    removed, as is any other zero constraint.
    """
    for constraints in (self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi):
        # Iterate over a snapshot: new copies are appended to the live list.
        for expr in list(constraints):
            if not expr.ispresent(v0):
                continue
            constraints.extend(expr.substituted(v0, v1) for v1 in v1s)
            expr.setzero()
    for attr in ("exprs_ge", "exprs_eq", "exprs_gei", "exprs_eqi"):
        setattr(self, attr, [expr for expr in getattr(self, attr) if not expr.iszero()])
# def flattened_minmax(self, term, sgn, bds):
# sbds = self.get_lb_ub_eq(term)
# reg.eliminate_term(self)
def lowest_present(self, v, sn):
    """Return self if sn is truthy and v occurs in this region, otherwise None."""
    return self if sn and self.ispresent(v) else None
def flatten_regterm(self, term):
    """Eliminate the region-defined real term `term` and return the resulting region.

    `term` carries its defining region in term.reg and an orientation in
    term.sn. Presumably sn is +1 / -1 for a maximum / minimum and 0
    otherwise (TODO confirm against the Term class).

    Procedure:
      - Equalities involving term are first split into `>=` pairs.
      - Occurrences recorded in sn_present[1] (coefficient sign agreeing
        with sn in exprs_ge, opposite in exprs_gei) are replaced by a
        temporary real variable. That variable is tied to term.reg via
        corners_optimum_eq and then eliminated.
      - Occurrences recorded in sn_present[0] (the other direction, or
        whenever sn == 0) are handled by making term.reg, reduced by
        corners_optimum, a hypothesis of the region. term itself is then
        renamed to a plain real variable.
      - Conditional-independence hypotheses (newindep) are added so that
        the defining region's auxiliaries are independent of the outside
        variables given its own variables.

    Note: self is modified in place during the first phase, and the
    returned object may be a different (RegionOp) instance.
    """
    self.simplify_quick()
    sn = term.sn
    sn_present = [False] * 2
    for x in self.exprs_ge:
        c = x.get_coeff(term)
        if c * sn > 0.0:
            sn_present[1] = True
        elif c * sn < 0.0:
            sn_present[0] = True
    for x in self.exprs_eq:
        c = x.get_coeff(term)
        if c != 0.0:
            sn_present[0] = True
            sn_present[1] = True
            self.exprs_ge.append(x.copy())
            self.exprs_ge.append(-x)
            x.setzero()
    for x in self.exprs_gei:
        c = x.get_coeff(term)
        if c * sn > 0.0:
            sn_present[0] = True
        elif c * sn < 0.0:
            sn_present[1] = True
    for x in self.exprs_eqi:
        c = x.get_coeff(term)
        if c != 0.0:
            sn_present[0] = True
            sn_present[1] = True
            self.exprs_gei.append(x.copy())
            self.exprs_gei.append(-x)
            x.setzero()
    self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()]
    self.exprs_eqi = [x for x in self.exprs_eqi if not x.iszero()]
    cs = self
    if sn == 0:
        # No orientation: always use the hypothesis-based branch only.
        sn_present[1] = False
        sn_present[0] = True
    if sn_present[1]:
        tmpvar = Expr.real("FLAT_TMP_" + str(term))
        for x in cs.exprs_ge:
            c = x.get_coeff(term)
            if c * sn > 0.0:
                x.substitute(Expr.fromterm(term), tmpvar)
        for x in cs.exprs_gei:
            c = x.get_coeff(term)
            if c * sn < 0.0:
                x.substitute(Expr.fromterm(term), tmpvar)
        reg2 = term.reg.copy()
        reg2.substitute(Expr.fromterm(term), tmpvar)
        cs.aux_avoid(reg2)
        newindep = Expr.Ic(reg2.getauxi(), cs.allcomprv() - cs.getaux() - reg2.allcomprv(),
            reg2.allcomprv_noaux()).simplified_quick()
        if reg2.get_type() == RegionType.NORMAL:
            reg2 = reg2.corners_optimum_eq(tmpvar, sn)
        cs = reg2 & cs
        if not newindep.iszero():
            cs = cs.iand_norename((newindep == 0).imp_flipped())
        cs.eliminate_quick(tmpvar)
    if sn_present[0]:
        newvar = Expr.real(str(term))
        reg2 = term.reg.copy()
        cs.aux_avoid(reg2)
        newindep = Expr.Ic(reg2.getaux(), cs.allcomprv() - cs.getaux() - reg2.allcomprv(),
            reg2.allcomprv_noaux()).simplified_quick()
        if reg2.get_type() == RegionType.NORMAL:
            reg2 = reg2.corners_optimum(Expr.fromterm(term), sn)
        cs = cs.implicated(reg2)
        if not newindep.iszero():
            cs = cs.iand_norename((newindep == 0).imp_flipped())
        cs.substitute(Expr.fromterm(term), newvar)
    return cs
def flatten_ivar(self, ivar):
    """Expand a region-defined random variable ivar and return the resulting region.

    Procedure:
      - ivar's defining region (ivar.reg) becomes a hypothesis of this region.
      - ivar is replaced by a plain copy without the region (copy_noreg).
      - Unless ivar.reg_det is set, a hypothesis is added making the
        defining region's auxiliaries and the new variable conditionally
        independent of the outside variables, given the defining region's
        other variables.

    Note: aux_avoid is called on self before `implicated`, so self's
    auxiliaries may be renamed in place.
    """
    cs = self
    newvar = Comp([ivar.copy_noreg()])
    reg2 = ivar.reg.copy()
    cs.aux_avoid(reg2)
    cs = cs.implicated(reg2, skip_simplify = True)
    cs.substitute(Comp([ivar]), newvar)
    if not ivar.reg_det:
        newindep = Expr.Ic(reg2.getaux() + newvar, cs.allcomprv() - cs.getaux() - reg2.allcomprv(),
            reg2.allcomprv_noaux() - newvar).simplified_quick()
        if not newindep.iszero():
            cs = cs.iand_norename((newindep == 0).imp_flipped())
    return cs
def isregtermpresent(self):
    """Return True if any constraint contains a region-defined term."""
    every = self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi
    return any(expr.isregtermpresent() for expr in every)
def numexpr(self):
    """Return the total number of constraints (consequence and implication)."""
    return sum(len(lst) for lst in (self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi))
def numterm(self):
    """Return the total number of terms over all constraints."""
    total = 0
    for lst in (self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi):
        total += sum(len(expr.terms) for expr in lst)
    return total
def isplain(self):
    """Return True if the region has no existential auxiliaries and no region-defined terms."""
    if not self.aux.isempty():
        return False
    return not self.isregtermpresent()
def regtermmap(self, cmap, recur):
    """Collect region-defined random variables and region terms into cmap, keyed by name.

    Scanned sources:
      - every term of every constraint;
      - the auxiliaries, scanned through H(aux + auxi).
    Entries already in cmap are left alone. When recur is true, the
    defining region of each newly added entry is scanned recursively.
    """
    sources = self.exprs_ge + self.exprs_eq + self.exprs_gei + self.exprs_eqi + [Expr.H(self.aux + self.auxi)]
    for expr in sources:
        for term, _coeff in expr.terms:
            for var in term.allcomprv_shallow().varlist:
                if var.reg is None or var.name in cmap:
                    continue
                cmap[var.name] = var
                if recur:
                    var.reg.regtermmap(cmap, recur)
            if term.get_type() != TermType.REGION:
                continue
            key = term.x[0].varlist[0].name
            if key in cmap:
                continue
            cmap[key] = term
            if recur:
                term.reg.regtermmap(cmap, recur)
def flatten(self):
    """Expand the definitions of all region-defined terms and variables; return the result.

    Procedure:
      - Only outermost definitions are expanded: those not referenced from
        inside another definition (regterms_in).
      - Each pass expands exactly one definition, using flatten_ivar for
        random variables and flatten_regterm for real terms.
      - The method then recurses until nothing is left to expand.

    Settings:
      - "verbose_flatten" prints each step.
      - "proof_enabled" records a proof step for every expansion.
    """
    verbose = PsiOpts.settings.get("verbose_flatten", False)
    write_pf_enabled = PsiOpts.settings.get("proof_enabled", False)
    cs = self
    did = False
    didall = False
    regterms = {}
    cs.regtermmap(regterms, False)
    regterms_in = {}
    for (name, term) in regterms.items():
        # Names that appear inside some definition are expanded later, by recursion.
        term.reg.regtermmap(regterms_in, True)
    for (name, term) in regterms.items():
        if not(name in regterms_in):
            if verbose:
                print("========= flatten ========")
                print(cs)
                print("========= term ========")
                print(term)
                print("========= region ========")
                print(term.reg)
            if isinstance(term, IVar):
                cs = cs.flatten_ivar(term)
            else:
                cs = cs.flatten_regterm(term)
            did = True
            didall = True
            if verbose:
                print("========= to ========")
                print(cs)
            # Expand one definition per pass; the recursion below handles the rest.
            break
    if write_pf_enabled:
        if didall:
            # NOTE(review): the proof step records the pre-expansion region (self), not cs.
            pf = ProofObj.from_region(self, c = "Expanded definitions to")
            PsiOpts.set_setting(proof_add = pf)
    if did:
        return cs.flatten()
    return cs
def flatten_term(self, x, isimp = True):
    """Expand the definitions of the variables or terms in x and return the result.

    Args:
        x: a Comp, whose IVars are expanded with flatten_ivar, or an Expr,
            whose terms are expanded with flatten_regterm.
        isimp: forwarded only when the current receiver is a RegionOp. The
            plain-Region expanders take just the variable/term, and passing
            isimp to them used to raise TypeError.
    """
    cs = self
    if isinstance(x, Comp):
        for var in x.varlist:
            if isinstance(cs, RegionOp):
                cs = cs.flatten_ivar(var, isimp)
            else:
                cs = cs.flatten_ivar(var)
    else:
        for term, _coeff in x.terms:
            if isinstance(cs, RegionOp):
                cs = cs.flatten_regterm(term, isimp)
            else:
                cs = cs.flatten_regterm(term)
    return cs
def flattened(self, *args, minmax_elim = False):
    """Return a copy of this region with region-defined terms expanded.

    Args:
        *args: additional variables/terms (Comp or Expr) to expand via flatten_term.
        minmax_elim: forwarded to the RegionOp flatten().

    Returns:
        A plain copy when there is nothing to expand. Otherwise the
        simplest available form: tosimple() if possible, else the RegionOp.
    """
    if not self.isregtermpresent() and len(args) == 0:
        return self.copy()
    # toregionop copies a RegionOp receiver and wraps a plain Region in a
    # one-member intersection. (The old code tested an always-None local here.)
    cs = self.toregionop()
    cs.flatten(minmax_elim = minmax_elim)
    for x in args:
        cs.flatten_term(x)
    simple = cs.tosimple()
    return simple if simple is not None else cs
def incorporate_tmp(self, x, tmplink):
    """Intersect with a constraint tying x to tmplink and return the (possibly rebound) region.

    The constraint is H(x) == tmplink for a Comp x, and x == tmplink for an
    Expr x. Other types are ignored.
    """
    if isinstance(x, Comp):
        linked = H(x)
    elif isinstance(x, Expr):
        linked = x
    else:
        return self
    self &= (linked == tmplink)
    return self
def incorporate_remove_tmp(self, tmplink):
    """Drop every constraint that mentions tmplink, undoing incorporate_tmp's links."""
    self.remove_present(tmplink)
    return None
def to_cause_consequence(self):
    """Return [(premise, consequence, copy of auxi)] describing this region."""
    premise = self.imp_flippedonly_noaux()
    consequence = self.consonly()
    return [(premise, consequence, self.auxi.copy())]
def and_cause_consequence(self, other, avoid = None, added_reg = None):
    """Apply the implication(s) in other to a copy of this region and return it.

    For each (premise, consequence, auxi) of other.to_cause_consequence():
      - The premise is checked against the current region
        (check_getaux_inplace_gen on `premise.exists(auxi) >> cs`).
      - For every witness found, the consequence with the witness
        substituted is intersected in.
      - Processing stops early if the region stops being "simple" (tosimple
        returns None). Non-simple premises are skipped.

    Args:
        avoid: if given, each universally quantified auxiliary is hinted to avoid a copy of it.
        added_reg: if given, also receives each added consequence via `&=`.
            NOTE(review): this only rebinds the local name, so it updates
            the caller's object only when __iand__ mutates in place.
    """
    cs = self.copy()
    other = other.copy()
    cs.aux_avoid(other)
    clist = other.to_cause_consequence()
    for imp, cons, auxi in clist:
        tcs = cs.tosimple()
        if tcs is None:
            return cs
        cs = tcs
        timp = imp.tosimple()
        if timp is None:
            continue
        # tcs becomes "exists auxi: imp" implies cs.
        tcs = imp.exists(auxi)
        tcs.implicate(cs)
        hint_aux_avoid = None
        if avoid is not None:
            hint_aux_avoid = []
            for a in auxi:
                hint_aux_avoid.append((a, avoid.copy()))
        for rr in tcs.check_getaux_inplace_gen(hint_aux_avoid = hint_aux_avoid):
            # An empty signal type marks an actual witness (not a control signal).
            if iutil.signal_type(rr) == "":
                tcons = cons.copy()
                Comp.substitute_list(tcons, rr)
                cs &= tcons
                if added_reg is not None:
                    added_reg &= tcons
    return cs
def toregionop(self):
    """Return a RegionOp version of this region.

    A RegionOp is copied; a plain Region is wrapped in a one-member intersection.
    """
    if isinstance(self, RegionOp):
        return self.copy()
    return RegionOp.inter([self])
def get_req_cons(self):
    """Delegate get_req_cons to the RegionOp form of this region."""
    as_op = self.toregionop()
    return as_op.get_req_cons()
def incorporated(self, *args):
    """Return this region with the given definitions/facts incorporated, quickly simplified.

    args may mix regions, Comps and Exprs. A Comp is split into its
    individual variables. Items are processed in order:
      - A Region is applied as a cause/consequence rule (and_cause_consequence).
      - Otherwise the item's definition is expanded, and later arguments
        refer to the plain (definition-free) copy.
      - If the plain copy no longer appears in the region, it is
        temporarily linked to a "#TMPLINK" real variable so it is not lost.
        Those links are removed at the end.
    """
    cs = self.toregionop()
    cargs = []
    for a in args:
        if isinstance(a, Comp):
            for b in a:
                cargs.append(b.copy())
        else:
            cargs.append(a.copy())
    to_remove_aux = []
    tmplink = Expr.real("#TMPLINK")
    for i in range(len(cargs)):
        x = cargs[i]
        if isinstance(x, Region):
            cs = cs.and_cause_consequence(x).toregionop()
        else:
            cs.flatten_term(x, isimp = False)
            x_noreg = x.copy_noreg()
            for j in range(i + 1, len(cargs)):
                # A region that does not mention the random variable x needs no substitution.
                if isinstance(x, Comp) and isinstance(cargs[j], Region) and x not in cargs[j].allcomprv_noaux():
                    continue
                cargs[j].substitute(x, x_noreg)
            if not cs.ispresent(x_noreg):
                cs.incorporate_tmp(x_noreg, tmplink)
                to_remove_aux.append(x_noreg)
    if len(to_remove_aux):
        cs.incorporate_remove_tmp(tmplink)
    return cs.simplified_quick()
def instantiated(self, *args):
    """Return the universe with args and then this region incorporated (see incorporated)."""
    everything = universe()
    return everything.incorporated(*args, self)
def flattened_self(self):
    """Return a flattened copy of this region, then quickly simplified."""
    return self.copy().flatten().simplify_quick()
def tosimple(self):
    """Return a copy if the region has no universal auxiliaries, else None."""
    return None if not self.auxi.isempty() else self.copy()
def tosimple_noaux(self):
    """Like tosimple, but also return None when any auxiliary is present."""
    return None if self.aux_present() else self.tosimple()
def tosimple_safe(self):
    """For a plain Region this is the same as tosimple()."""
    simple = self.tosimple()
    return simple
def tonormal_safe(self):
    """A plain Region is already normal, so just return a copy."""
    duplicate = self.copy()
    return duplicate
def level(self, mode = ""):
    """The level of a region in the linear entropy hierarchy.

    Starts from Level(), then applies .exists() if there are existential
    auxiliaries and .forall() if there are universal ones. mode is not used
    by this implementation.
    """
    result = Level()
    has_exists = not self.aux.isempty()
    if has_exists:
        result = result.exists()
    has_forall = not self.auxi.isempty()
    if has_forall:
        result = result.forall()
    return result
def region_level(self, mode = ""):
    """Return this region's level (see level) wrapped in a RegionLevel."""
    lvl = self.level(mode = mode)
    return RegionLevel(lvl)
def complexity(self):
    """Heuristic size of the region.

    Sum of the constraints' complexities, plus 100 per quantified auxiliary variable.
    """
    lists = (self.exprs_eq, self.exprs_ge, self.exprs_eqi, self.exprs_gei)
    expr_cost = sum(expr.complexity() for lst in lists for expr in lst)
    return expr_cost + (len(self.aux) + len(self.auxi)) * 100
def sorting_priority(self):
    """Sort key for regions; currently identical to complexity()."""
    priority = self.complexity()
    return priority
def z3(self, vardict = None):
    """Translate this region into a Z3 formula, or return None if Z3 is unavailable.

    Cases:
      - With implications, the formula is Or(Not(premise), consequence),
        universally quantified over auxi when present.
      - Otherwise it is the conjunction of the equalities and `>=`
        constraints (the single constraint itself when there is one),
        existentially quantified over aux when present.
    """
    if z3 is None:
        return None
    if vardict is None:
        vardict = iutil.z3_vardict(self.rvs, self.aux, self.auxi, self.reals)
    if self.imp_present():
        formula = z3.Or(z3.Not(self.imp_flippedonly_noaux().z3(vardict)), self.consonly().z3(vardict))
        if self.auxi.isempty():
            return formula
        return z3.ForAll([x.z3(vardict) for x in self.auxi], formula)
    clauses = [(x.z3(vardict) == 0) for x in self.exprs_eq]
    clauses += [(x.z3(vardict) >= 0) for x in self.exprs_ge]
    formula = clauses[0] if len(clauses) == 1 else z3.And(clauses)
    if self.aux.isempty():
        return formula
    return z3.Exists([x.z3(vardict) for x in self.aux], formula)
def check_z3(self):
    """Return True if Z3 proves this region, i.e. its negation is unsatisfiable.

    Assumptions are added first via `>>`, each with its source disabled
    during the recursive check:
      - the global "truth" setting, if set;
      - the region from get_indreg_checked(), if any.
    """
    truth = PsiOpts.settings["truth"]
    if truth is not None:
        with PsiOpts(truth = None):
            return (truth >> self).check_z3()
    indreg = self.get_indreg_checked()
    if indreg is not None:
        with PsiOpts(indreg_enabled = False):
            return (indreg >> self).check_z3()
    vardict = iutil.z3_vardict(self.rvs, self.aux, self.auxi, self.reals)
    formula = self.z3(vardict)
    solver = vardict["#solver"]
    solver.add(z3.Not(formula))
    return solver.check() == z3.unsat
def commonpart_extend(self, v, forbid_h = True):
    """Try to simplify the region by introducing and eliminating a common-part variable for v.

    Procedure:
      - For each constraint, x.commonpart_coeff(v, forbid_h = forbid_h)
        gives a coefficient c.
      - Constraints with |c| > eps receive `+ c * #TVAR`.
      - Inequalities with c = +inf are deleted.
      - Finally #TVAR >= 0 is added and #TVAR is eliminated.

    Returns:
        self on success. None, without modifying self, in any of these cases:
          - a coefficient is None, NaN, or infinite where not allowed;
          - a sign check fails;
          - no positive contribution (didpos) is found.
    """
    ceps = PsiOpts.settings["eps"]
    tvar = Expr.real("#TVAR")
    did = False  # NOTE(review): assigned but never read.
    didpos = False
    toadd = []
    todel = set()
    for x in self.exprs_eq:
        c = x.commonpart_coeff(v, forbid_h = forbid_h)
        if c is None or numpy.isnan(c):
            return None
        if numpy.isinf(c):
            return None
        if abs(c) <= ceps:
            continue
        if x.isnonneg() and c > 0:
            return None
        did = True
        didpos = True
        toadd.append((x, c))
    for x in self.exprs_ge:
        c = x.commonpart_coeff(v, forbid_h = forbid_h)
        if c is None or numpy.isnan(c):
            return None
        if numpy.isinf(c):
            if c < 0:
                return None
            # +inf: the inequality is dropped instead of extended.
            todel.add(x)
            continue
        if abs(c) <= ceps:
            continue
        if x.isnonpos() and c < 0:
            return None
        did = True
        if c > 0:
            didpos = True
        toadd.append((x, c))
    for x in self.exprs_eqi:
        c = x.commonpart_coeff(v, forbid_h = forbid_h)
        if c is None or numpy.isnan(c):
            return None
        if numpy.isinf(c):
            return None
        if abs(c) <= ceps:
            continue
        did = True
        didpos = True
        toadd.append((x, c))
    for x in self.exprs_gei:
        c = x.commonpart_coeff(v, forbid_h = forbid_h)
        if c is None or numpy.isnan(c):
            return None
        if numpy.isinf(c):
            return None
        if abs(c) <= ceps:
            continue
        did = True
        didpos = True
        toadd.append((x, c))
    if not didpos:
        return None
    # NOTE(review): membership tests on a set of Exprs rely on Expr hashing/identity;
    # confirm Expr.__eq__ (which may build a region) is not hit on hash collisions.
    self.exprs_eq = [x for x in self.exprs_eq if x not in todel]
    self.exprs_ge = [x for x in self.exprs_ge if x not in todel]
    self.exprs_eqi = [x for x in self.exprs_eqi if x not in todel]
    self.exprs_gei = [x for x in self.exprs_gei if x not in todel]
    for x, c in toadd:
        # Assumes Expr.__iadd__ mutates in place (otherwise this only rebinds the loop variable).
        x += tvar * c
    self.exprs_ge.append(tvar)
    self.eliminate(tvar)
    return self
def var_neighbors(self, v):
    """Return a copy of v extended by each constraint's var_neighbors(v).

    The constraints are visited in the order eq, ge, eqi, gei.
    """
    neighbors = v.copy()
    for lst in (self.exprs_eq, self.exprs_ge, self.exprs_eqi, self.exprs_gei):
        for expr in lst:
            neighbors += expr.var_neighbors(v)
    return neighbors
def aux_strengthen(self, addrv = None, other_neighbors = None):
    """Add conditional-independence constraints restricting the auxiliaries, in place.

    Procedure:
      - For each group of auxiliaries, the constraint
        I(group; everything else | neighbours) = 0 is appended.
      - The neighbours of an auxiliary are found with var_neighbors in this
        region (and in other_neighbors, if given).
      - Connected groups of auxiliaries are processed recursively, peeling
        off the highest-degree auxiliary each time.

    Args:
        addrv: extra random variables counted as part of the region.
        other_neighbors: another region whose neighbourhoods are also considered.

    Regions with implications apply this to their consequence part only.

    Returns:
        self.
    """
    if self.aux.isempty():
        return self
    if self.imp_present():
        cs = self.consonly()
        # NOTE(review): other_neighbors is not forwarded here.
        cs.aux_strengthen(addrv)
        self.cons_shallow_copy_(cs)
        return self
    allc = self.allcomprv()
    if addrv is not None:
        allc += addrv
    toadd = Region.universe()
    ag = []    # neighbourhood of each auxiliary
    acon = []  # neighbourhood closure over connected auxiliaries
    for a in self.aux:
        v = self.var_neighbors(a)
        if other_neighbors is not None:
            v = v + other_neighbors.var_neighbors(a)
        ag.append(v.copy())
        acon.append(v.copy())
    # Propagate neighbourhoods between auxiliaries that see each other (fixed-point style).
    for it in range(len(self.aux)):
        for i, a in enumerate(self.aux):
            for j, b in enumerate(self.aux):
                if acon[i].ispresent(b):
                    acon[i] += acon[j]
                    acon[j] += acon[i]
    def recur(ca, prevv):
        # ca: the auxiliaries still to process; prevv: neighbourhood used in the previous step.
        v = Comp.empty()
        maxdeg = -1
        maxdega = None
        for i, a in enumerate(self.aux):
            if ca.ispresent(a):
                if not acon[i].super_of(ca):
                    # ca spans several connected components: split and handle each separately.
                    recur(acon[i].inter(ca), None)
                    recur(ca - acon[i], None)
                    return
                v += ag[i]
                deg = len(ag[i]) + len(ag[i] - ca) * 50
                if deg > maxdeg:
                    maxdeg = deg
                    maxdega = a
        if maxdeg < 0:
            return
        if prevv is None or v != prevv:
            b = allc - v
            if not b.isempty():
                toadd.exprs_eq.append(Expr.Ic(ca, b, v - ca))
        recur(ca - maxdega, v)
    recur(self.aux.copy(), None)
    self.iand_norename(toadd)
    return self
def aux_strengthen_old(self, addrv = None, other_neighbors = None):
    """Legacy version of aux_strengthen: add per-auxiliary independence constraints, in place.

    Procedure, for each auxiliary a in turn:
      - Let v be a's neighbours (from this region and other_neighbors),
        restricted to the variables known so far.
      - Append I(a; allc - v | v - a) = 0, where allc holds this region's
        non-auxiliary variables plus the auxiliaries processed earlier.
      - Then mark a as known.

    Args:
        addrv: extra random variables counted as part of the region.
        other_neighbors: another region whose variables and neighbourhoods are also considered.

    Regions with implications delegate to aux_strengthen on the consequence part.

    Returns:
        self.
    """
    if self.aux.isempty():
        return self
    if self.imp_present():
        cs = self.consonly()
        cs.aux_strengthen(addrv)
        self.cons_shallow_copy_(cs)
        return self
    allc = self.allcomprv()
    if addrv is not None:
        allc += addrv
    toadd = Region.universe()
    # Independent copy: the in-place updates below must not leak between allc and allc_on.
    allc_on = allc.copy()
    if other_neighbors is not None:
        allc_on += other_neighbors.allcomprv()
    allc -= self.aux
    allc_on -= self.aux
    for a in self.aux:
        v = self.var_neighbors(a)
        if other_neighbors is not None:
            v = v + other_neighbors.var_neighbors(a)
        v = v.inter(allc_on)
        b = allc - v
        if not b.isempty():
            toadd.exprs_eq.append(Expr.Ic(a, b, v - a))
        allc += a
        allc_on += a
    self.iand_norename(toadd)
    return self
def simplify_aux_commonpart(self, reg = None, minlen = 1, maxlen = 1, forbid_h = True):
    """Try commonpart_extend on subsets of the existential auxiliaries, in place.

    Subset handling:
      - For each size clen in [minlen, maxlen] (size 2 is skipped), every
        subset of aux of that size is tried.
      - For clen > 1, every partition of the subset into clen parts is tried.
      - The search stops early when PsiOpts.is_timer_ended().

    The auxiliary lists are cleared during the search and restored afterwards.
    Regions with implications apply this to their consequence part.

    Returns:
        self.
    """
    if self.aux.isempty():
        return self
    if self.imp_present():
        cs = self.consonly()
        cs.simplify_aux_commonpart(reg, minlen, maxlen, forbid_h)
        self.cons_shallow_copy_(cs)
        return self
    if reg is None:
        reg = Region.universe()  # NOTE(review): reg is not used further in this method.
    did = False  # NOTE(review): assigned but never read.
    taux = self.aux.copy()
    taux2 = taux.copy()
    tauxi = self.auxi.copy()
    # Hide the auxiliaries while commonpart_extend/eliminate run; restored below.
    self.aux = Comp.empty()
    self.auxi = Comp.empty()
    for clen in range(minlen, maxlen + 1):
        if clen == 2:
            continue
        if clen > len(taux):
            break
        for v in igen.subset(taux, minsize = clen):
            if PsiOpts.is_timer_ended():
                break
            if clen == 1:
                if self.commonpart_extend(v, forbid_h = forbid_h) is not None:
                    did = True
            else:
                for v2 in igen.partition(v, clen):
                    if PsiOpts.is_timer_ended():
                        break
                    if self.commonpart_extend(v2, forbid_h = forbid_h) is not None:
                        did = True
    self.aux = taux2
    self.auxi = tauxi
    return self
def remove_realvar_sn(self, v, sn):
    """Drop, in place, every `>=` constraint whose coefficient on v has the sign of sn.

    v is the real variable given by v.terms[0][0]. Constraints that are
    already zero are removed as well.
    """
    key = v.terms[0][0]
    for ineq in self.exprs_ge:
        if ineq.get_coeff(key) * sn > 0:
            ineq.setzero()
    self.exprs_ge = [ineq for ineq in self.exprs_ge if not ineq.iszero()]
def istight(self, canon = False):
    """Return True if every constraint reports istight(canon)."""
    for lst in (self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi):
        for expr in lst:
            if not expr.istight(canon):
                return False
    return True
def tighten(self):
    """Call tighten() on every constraint, in place."""
    for lst in (self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi):
        for expr in lst:
            expr.tighten()
def add_meta_present(self, b, key, value):
    """Attach metadata (key, value) to every constraint that mentions b."""
    for lst in (self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi):
        for expr in lst:
            if expr.ispresent(b):
                expr.add_meta(key, value)
def optimum(self, v, b, sn, name = None, reg_outer = None, assume_feasible = True, allow_reuse = False, quick = None, quick_outer = None, tighten = False):
    """Return the variable obtained from maximizing (sn=1)
    or minimizing (sn=-1) the expression v over variables b (Comp, Expr or list)

    Args:
        v: expression to optimize.
        b: variables to optimize over (Comp, Expr or list). The components
            of v are removed from b, and the rest are eliminated from a copy
            of the region. If None, nothing is eliminated.
        sn: > 0 for maximization, otherwise minimization.
        name: name of the resulting real variable. If None, the name is
            "max"/"min", then a short hash of self, then "(" + str(v) + ")".
        reg_outer: if given, the optimum is also computed over reg_outer.
            Its region is attached as ``reg_outer`` of the result term.
        assume_feasible: if True and the reduced region has type
            RegionType.NORMAL, the ">= 0" constraints whose coefficient on
            v has the sign of sn are dropped (via remove_realvar_sn).
        allow_reuse: if True and v is a single real term, that term's
            variable is used directly instead of a new variable.
        quick: False selects ``eliminate``; any other value selects
            ``eliminate_quick``. If None, it is computed before the
            recursive call at the end of this method.
        quick_outer: value of ``quick`` used for reg_outer (defaults to quick).
        tighten: only used when reg_outer is given; the result is then
            tightened with ``tighten()``.

    Returns:
        An Expr wrapping a Term that carries the reduced region and sn,
        multiplied by the coefficient of v.
    """
    if reg_outer is not None:
        # Optimize separately over self and reg_outer, give both results the
        # same variable, and attach the outer region to the inner result.
        if quick_outer is None:
            quick_outer = quick
        a0 = self.optimum(v, b, sn, name = name, reg_outer = None, assume_feasible = assume_feasible, allow_reuse = allow_reuse, quick = quick)
        a1 = reg_outer.optimum(v, b, sn, name = name, reg_outer = None, assume_feasible = assume_feasible, allow_reuse = allow_reuse, quick = quick_outer)
        # Rename a1's optimum variable to a0's.
        a1.terms[0][0].substitute(a1.terms[0][0].x[0], a0.terms[0][0].x[0])
        a0.terms[0][0].reg_outer = a1.terms[0][0].reg
        if tighten:
            a0.tighten()
        return a0
    # Name of the real variable that represents the optimum.
    tmpstr = ""
    if name is not None:
        tmpstr = name
    else:
        if sn > 0:
            tmpstr = "max"
        else:
            tmpstr = "min"
        tmpstr += str(iutil.hash_short(self))
        tmpstr = tmpstr + "(" + str(v) + ")"
    tmpvar = Expr.real(tmpstr)
    if allow_reuse and v.size() == 1 and v.terms[0][0].get_type() == TermType.REAL:
        # v is a single real term c*t: optimize t directly and scale by c at the end.
        coeff = v.terms[0][1]
        if coeff < 0:
            # Optimizing c*t with c < 0 flips the direction in t.
            sn *= -1
        cs = self.copy()
        if b is not None:
            b = Region.get_allcomp(b) - v.allcomp()
            if quick is False:
                cs.eliminate(b)
            else:
                cs.eliminate_quick(b)
        if assume_feasible:
            if cs.get_type() == RegionType.NORMAL:
                cs.remove_realvar_sn(v, sn)
        # NOTE(review): this branch sits inside `if allow_reuse ...`, so it
        # appears unreachable. Confirm before relying on it.
        if not allow_reuse:
            cs.substitute(Expr.fromterm(Term(v.terms[0][0].copy().x, Comp.empty())), tmpvar)
            v = tmpvar
        cs.add_meta_present(v, "pf_note", "at optimum")
        return Expr.fromterm(Term(v.terms[0][0].copy().x, Comp.empty(), cs, sn)) * coeff
    # General case: add tmpvar <= v (max) or tmpvar >= v (min), then
    # optimize tmpvar through the single-real-term path above.
    cs = self.copy()
    toadd = Region.universe()
    if sn > 0:
        toadd.exprs_ge.append(v - tmpvar)
    else:
        toadd.exprs_ge.append(tmpvar - v)
    toadd = toadd.flattened(minmax_elim = True)
    cs = cs & toadd
    # cs.iand_norename(toadd)
    if quick is None:
        quick = v.isrealvar() and v.allcomp().super_of(Region.get_allcomp(b))
    return cs.optimum(tmpvar, b, sn, name = name, reg_outer = reg_outer, assume_feasible = assume_feasible, allow_reuse = True, quick = quick)
def maximum(self, expr, vs, reg_outer = None, **kwargs):
    """Return the variable obtained from maximizing ``expr`` over ``vs``.

    ``vs`` may be a Comp, Expr or list. This is ``optimum`` with sn = 1.
    ``reg_outer`` and any other keyword arguments are forwarded unchanged.
    """
    maximize = 1
    return self.optimum(expr, vs, maximize, reg_outer = reg_outer, **kwargs)
def minimum(self, expr, vs, reg_outer = None, **kwargs):
    """Return the variable obtained from minimizing ``expr`` over ``vs``.

    ``vs`` may be a Comp, Expr or list. This is ``optimum`` with sn = -1.
    ``reg_outer`` and any other keyword arguments are forwarded unchanged.
    """
    minimize = -1
    return self.optimum(expr, vs, minimize, reg_outer = reg_outer, **kwargs)
def init_prog(self, index = None, lptype = None, save_res = False, lp_bounded = None, dual_enabled = None, val_enabled = None, simplified_prog = False, discover_exprs = None, discover_init_model = None, discover_method = ""):
    """Build and finish a LinearProg whose constraints are this region's implicant.

    Args:
        index: IVarIndex of the variables. If None, a fresh index is created
            and this region's variables are recorded into it.
        lptype: LinearProgType to use; defaults to PsiOpts.settings["lptype"].
            Forced to LinearProgType.H when the settings "lptype_H_if_proof"
            and "proof_enabled" are both set.
        save_res: forwarded to LinearProg. When True it also makes
            lp_bounded default to True.
        lp_bounded, dual_enabled, val_enabled: forwarded to LinearProg.
        simplified_prog: if True, lp_bounded, save_res and val_enabled are
            forced off. Weighted "sim" equality constraints on every H(x) and
            I(x; y) are added, and finish() skips entropy inequalities.
        discover_exprs: if not None, stored as prog.dual_discover_exprs.
        discover_init_model: if not None, stored as
            prog.dual_discover_init_model.
        discover_method: "l1" sets prog.dual_discover_l1 = True.

    Returns:
        The finished LinearProg (implicant constraints exprs_gei / exprs_eqi
        added).

    Raises:
        ValueError: if lptype is not LinearProgType.H, LinearProgType.HC1BN
            or LinearProgType.HMIN.
    """
    if index is None:
        index = IVarIndex()
        self.record_to(index)
    if lptype is None:
        lptype = PsiOpts.settings["lptype"]
    if PsiOpts.settings["lptype_H_if_proof"] and PsiOpts.settings["proof_enabled"]:
        lptype = LinearProgType.H
    if lp_bounded is None:
        if save_res:
            lp_bounded = True
    if simplified_prog:
        lp_bounded = False
        save_res = False
        val_enabled = False
    if lptype == LinearProgType.H:
        prog = LinearProg(index, lptype, lp_bounded = lp_bounded,
            save_res = save_res, prereg = self, dual_enabled = dual_enabled, val_enabled = val_enabled)
    elif lptype == LinearProgType.HC1BN or lptype == LinearProgType.HMIN:
        # Bayesian-network based program. The network's index is extended
        # with this region's variable index.
        bnet = self.get_bayesnet_imp(skip_simplify = True, add_hc = PsiOpts.settings["lp_bnet_hc"])
        if PsiOpts.settings["lp_bnet_reverse"]:
            bnet = bnet.reversed()
        kindex = bnet.index.copy()
        kindex.add_varindex(index)
        prog = LinearProg(kindex, lptype, bnet, lp_bounded = lp_bounded,
            save_res = save_res, prereg = self, dual_enabled = dual_enabled, val_enabled = val_enabled)
    else:
        # Without this, prog would stay None and the calls below would fail
        # with an opaque AttributeError.
        raise ValueError("Unsupported lptype: " + repr(lptype))
    for x in self.exprs_gei:
        prog.addExpr_ge0(x)
    for x in self.exprs_eqi:
        prog.addExpr_eq0(x)
    if simplified_prog:
        for x in igen.subset(index.comprv, minsize = 1):
            prog.addExpr_eq0((Expr.H(x) * (1.0 + 1.0 / len(x)) * 0.1).add_meta("sim", True))
        for x in index.comprv:
            for y in igen.subset(index.comprv - x, minsize = 1):
                # For equal-size pairs keep only one ordering, so I(x; y) and
                # I(y; x) are not both added.
                if len(x) == len(y) and str(x) > str(y):
                    continue
                prog.addExpr_eq0((Expr.I(x, y) * (1.0 + 1.0 / len(x + y)) * 0.1).add_meta("sim", True))
    if discover_exprs is not None:
        prog.dual_discover_exprs = discover_exprs
    if discover_method == "l1":
        prog.dual_discover_l1 = True
    if discover_init_model is not None:
        prog.dual_discover_init_model = discover_init_model
    prog.finish(skip_ent_ineq = simplified_prog)
    return prog
def init_simplified_prog(self, index = None):
    """Return the simplified linear program of the imp-flipped region.

    The program is built with ``lp_dual_form`` on and proofs off, using
    ``init_prog(..., simplified_prog = True)``.
    """
    with PsiOpts(lp_dual_form = True, proof_enabled = False):
        flipped = self.imp_flipped()
        return flipped.init_prog(index, simplified_prog = True)
def get_basis(self, more_vars = None):
    """Get a basis of the entropy region of this region (may not be minimal).

    The linear program is built from the imp-flipped consequence. Its
    variable index also records ``more_vars`` when given.
    """
    flipped = self.consonly().imp_flipped()
    var_index = IVarIndex()
    flipped.record_to(var_index)
    if more_vars is not None:
        more_vars.record_to(var_index)
    prog = flipped.init_prog(var_index)
    return prog.get_var_exprs()
def get_linear_dependency(self, vs):
    """Get the linear dependency between the terms.

    Each expression in ``vs`` is turned into a coefficient vector: the
    ``get_vec`` output of the linear program of the imp-flipped consequence,
    with the constant part appended. The vectors are scanned in order:

    * a vector whose entries are all within ``PsiOpts.settings["eps"]`` of
      zero gives ``[0] * len(vs)``;
    * a vector that is not (numerically) in the span of the vectors kept so
      far is kept as a new basis vector and gives ``None``;
    * otherwise the least-squares coefficients are returned as a list of
      length ``len(vs)``. Entry ``j`` is the coefficient of ``vs[j]`` if
      ``vs[j]`` is a kept basis vector, else 0.

    NOTE(review): the span test compares the *squared* residual sum returned
    by ``numpy.linalg.lstsq`` against ``eps``. Confirm this tolerance is
    intended.
    """
    tol = PsiOpts.settings["eps"]
    flipped = self.consonly().imp_flipped()
    var_index = IVarIndex()
    flipped.record_to(var_index)
    for expr in vs:
        expr.record_to(var_index)
    prog = flipped.init_prog(var_index)

    count = len(vs)
    result = [None] * count
    basis_row = [-1] * count  # row of vs[j] inside `basis`, or -1 if not kept
    basis = None
    for i, expr in enumerate(vs):
        vec = prog.get_vec(expr)
        row = vec[0] + [vec[1]]
        if all(abs(c) <= tol for c in row):
            result[i] = [0] * count
            continue
        if basis is None:
            basis_row[i] = 0
            basis = numpy.array([row])
            continue
        sol, residual = numpy.linalg.lstsq(basis.transpose(), row, rcond=None)[:2]
        if all(abs(c) <= tol for c in residual):
            result[i] = [sol[basis_row[j]] if basis_row[j] >= 0 else 0 for j in range(count)]
        else:
            basis_row[i] = basis.shape[0]
            basis = numpy.vstack([basis, row])
    return result
def get_prog_region(self, toreal = None, toreal_only = False):
    """Return the region read back from the H-type linear program of the
    imp-flipped consequence. ``toreal`` and ``toreal_only`` are forwarded
    to ``get_region``."""
    flipped = self.consonly().imp_flipped()
    var_index = IVarIndex()
    flipped.record_to(var_index)
    prog = flipped.init_prog(var_index, lptype = LinearProgType.H)
    return prog.get_region(toreal, toreal_only)
def get_extreme_rays(self):
    """Return the extreme rays of the H-type linear program of the
    imp-flipped consequence."""
    flipped = self.consonly().imp_flipped()
    var_index = IVarIndex()
    flipped.record_to(var_index)
    prog = flipped.init_prog(var_index, lptype = LinearProgType.H)
    return prog.get_extreme_rays()
def implies_ineq_cons_hash(self, expr, sg):
    """Return whether a consequence constraint has the same hash as ``expr``.

    For ``sg == ">="`` both ``exprs_ge`` and ``exprs_eq`` are searched.
    For any other ``sg`` only ``exprs_eq`` is searched.
    NOTE: matching uses ``hash()`` only, so equal hashes count as a match.
    """
    target = hash(expr)
    pools = (self.exprs_ge, self.exprs_eq) if sg == ">=" else (self.exprs_eq,)
    return any(hash(x) == target for pool in pools for x in pool)
def implies_ineq_cons_quick(self, expr, sg):
    """Return whether self implies expr >= 0 or expr == 0, without linear programming.

    This is a sound but incomplete check against the consequence constraints
    (``exprs_ge`` / ``exprs_eq``) using ``simplified_quick``. A False result
    only means the quick check did not succeed.

    Args:
        expr: the expression to check.
        sg: ">=" checks ``expr >= 0`` and "==" checks ``expr == 0``. Any
            other value returns False.
    """
    if sg == "==" and expr.isnonneg():
        # expr >= 0 already holds, so expr == 0 reduces to -expr >= 0.
        sg = ">="
        expr = -expr
    if sg == ">=":
        for x in self.exprs_ge:
            # x >= 0 and expr - x >= 0 together imply expr >= 0.
            if (expr - x).simplified_quick().isnonneg():
                return True
        for x in self.exprs_eq:
            # x == 0, so expr may dominate either x or -x.
            if (expr - x).simplified_quick().isnonneg():
                return True
            if (expr + x).simplified_quick().isnonneg():
                return True
        return False
    if sg == "==":
        for x in self.exprs_eq:
            # x == 0 implies expr == 0 when expr equals x or -x.
            if (expr - x).simplified_quick().iszero():
                return True
            if (expr + x).simplified_quick().iszero():
                return True
        return False
    return False
def implies_ineq_quick(self, expr, sg, bnet = None):
    """Return whether self implies expr >= 0 or expr == 0, without linear programming.

    This is a sound but incomplete check against the implicant constraints
    (``exprs_gei`` / ``exprs_eqi``) using ``simplified_quick``. A False
    result only means the quick check did not succeed.

    Args:
        expr: the expression to check.
        sg: ">=" checks ``expr >= 0`` and "==" checks ``expr == 0``. Any
            other value returns False unless expr simplifies to zero.
        bnet: optional Bayesian network forwarded to ``simplified_quick``.
    """
    expr = expr.simplified_quick(bnet = bnet)
    if expr.iszero():
        return True
    if sg == "==" and expr.isnonneg():
        # expr >= 0 already holds, so expr == 0 reduces to -expr >= 0.
        sg = ">="
        expr = -expr
    if sg == ">=":
        if expr.isnonneg():
            return True
        for x in self.exprs_gei:
            # x >= 0 and expr - x >= 0 together imply expr >= 0.
            if (expr - x).simplified_quick(bnet = bnet).isnonneg():
                return True
        for x in self.exprs_eqi:
            # x == 0, so expr may dominate either x or -x.
            if (expr - x).simplified_quick(bnet = bnet).isnonneg():
                return True
            if (expr + x).simplified_quick(bnet = bnet).isnonneg():
                return True
        return False
    if sg == "==":
        for x in self.exprs_eqi:
            # x == 0 implies expr == 0 when expr equals x or -x.
            if (expr - x).simplified_quick(bnet = bnet).iszero():
                return True
            if (expr + x).simplified_quick(bnet = bnet).iszero():
                return True
        return False
    return False
def implies_ineq_prog(self, index, progs, expr, sg, save_res = False, saved = False, allow_shortcut = True):
    """Return whether self implies ``expr sg 0`` using a (cached) linear program.

    ``progs`` is a list used as a cache. If it is empty, a program is built
    with ``init_prog(index, save_res = save_res)`` and appended to it.
    When ``saved`` is "both", the check must pass with ``saved = True`` and
    then with ``saved = False``. When ``saved`` is falsy and
    ``allow_shortcut`` is set, ``implies_ineq_quick`` is tried before the
    linear program.
    """
    if saved == "both":
        return all(
            self.implies_ineq_prog(index, progs, expr, sg, save_res, saved = flag, allow_shortcut = allow_shortcut)
            for flag in (True, False)
        )
    quick_ok = not saved and allow_shortcut and self.implies_ineq_quick(expr, sg)
    if quick_ok:
        return True
    if not progs:
        progs.append(self.init_prog(index, save_res = save_res))
    return bool(progs[0].checkexpr(expr, sg, saved = saved))
def implies_impflipped_saved(self, other, index, progs):
    """Check that self (already imp-flipped) implies each constraint of ``other``.

    Every ">= 0" and then every "== 0" constraint of ``other`` is checked with
    ``implies_ineq_prog(..., save_res = True, saved = "both")``, sharing the
    ``progs`` cache. Returns False at the first failure, True otherwise.
    Prints diagnostics when the "verbose_subset" setting is on.
    """
    verbose_subset = PsiOpts.settings.get("verbose_subset", False)
    if verbose_subset:
        print(self)
        print(other)
    for sg, group in ((">=", other.exprs_ge), ("==", other.exprs_eq)):
        for x in group:
            if self.implies_ineq_prog(index, progs, x, sg, save_res = True, saved = "both"):
                continue
            if verbose_subset:
                print(str(x) + " " + sg + " 0 FAIL")
            return False
    if verbose_subset:
        print("SUCCESS")
    return True
def implies_saved(self, other, index, progs):
    """Return whether self implies ``other``, reusing programs in ``progs``.

    If ``index`` is None, it is taken from ``progs[0]`` when ``progs`` is
    non-empty. Otherwise a fresh IVarIndex is built from self and other.
    The implication part of self is flipped in place for the check and
    flipped back afterwards, even when the check raises.
    """
    if index is None:
        if len(progs):
            index = progs[0].index
        else:
            index = IVarIndex()
            self.record_to(index)
            other.record_to(index)
    self.imp_flip()
    try:
        return self.implies_impflipped_saved(other, index, progs)
    finally:
        # Restore self even if the check raised.
        self.imp_flip()
def check_quick(self, bnet = None, skip_simplify = False):
    """Return whether implication is true, using only the LP-free quick checks.

    Unless ``skip_simplify`` is set, the region is first simplified with
    ``simplified_quick(zero_group = 2)`` and split. Every consequence
    constraint must then pass ``implies_ineq_quick`` (with ``bnet``).
    """
    verbose_subset = PsiOpts.settings.get("verbose_subset", False)
    if skip_simplify:
        cs = self
    else:
        cs = self.simplified_quick(zero_group = 2)
        cs.split()
    if verbose_subset:
        print(cs)
    for sg, group in ((">=", cs.exprs_ge), ("==", cs.exprs_eq)):
        for x in group:
            if cs.implies_ineq_quick(x, sg, bnet = bnet):
                continue
            if verbose_subset:
                print(str(x) + " " + sg + " 0 FAIL")
            return False
    if verbose_subset:
        print("SUCCESS")
    return True
def check_plain(self, skip_simplify = False, quick = False, bnet = None):
    """Return whether implication is true.

    With ``quick`` set, this delegates to ``check_quick``. Otherwise each
    consequence constraint is checked with ``implies_ineq_prog`` against one
    shared linear program.

    The quick shortcut (and the initial ``simplified_quick``) is disabled
    when both "proof_noskip" and "proof_enabled" are set. Unless
    ``skip_simplify`` is set, the region is split when "lp_zero_group" is
    off, or when "lp_no_zero_group_if_proof" and "proof_enabled" are both
    set.
    NOTE(review): when the shortcut is disabled, ``split()`` runs on self
    itself, not on a copy.
    """
    if quick:
        return self.check_quick(bnet = bnet, skip_simplify = skip_simplify)
    verbose_subset = PsiOpts.settings.get("verbose_subset", False)
    allow_shortcut = not (PsiOpts.settings["proof_noskip"] and PsiOpts.settings["proof_enabled"])
    cs = self
    if not skip_simplify and allow_shortcut:
        cs = self.simplified_quick(zero_group = 2)
    if not skip_simplify and (not PsiOpts.settings.get("lp_zero_group", False) or (PsiOpts.settings["lp_no_zero_group_if_proof"] and PsiOpts.settings["proof_enabled"])):
        cs.split()
    index = IVarIndex()
    cs.record_to(index)
    progs = []
    if verbose_subset:
        print(cs)
    for sg, group in ((">=", cs.exprs_ge), ("==", cs.exprs_eq)):
        for x in group:
            if cs.implies_ineq_prog(index, progs, x, sg, allow_shortcut = allow_shortcut):
                continue
            if verbose_subset:
                print(str(x) + " " + sg + " 0 FAIL")
            return False
    if verbose_subset:
        print("SUCCESS")
    return True
def check_dual(self, optimal_terms = True, force_H = False, mul_coeff = True, skip_simplify = True, existing_only = False):
    """Prove the implication through the dual LP and return the assumptions used.

    Each consequence constraint is checked against the implicant with a
    linear program built in dual form (``lp_dual_form = True``,
    ``dual_enabled = True``), with the quick shortcut disabled. This covers
    every ">= 0" constraint, and every "== 0" constraint in both directions.

    Args:
        optimal_terms: if True, collect ``get_dual_region_terms(x)`` for each
            checked constraint ``x``, falling back to ``x >= 0`` when that
            returns None.
        force_H: if True, build the program with ``lptype = "H"``.
        mul_coeff: in the ``existing_only`` mode, scale the collected
            constraints by their dual values. In the ``get_dual_region``
            mode, it is forwarded to ``get_dual_region``.
        skip_simplify: if False, pass ``omit_trivial = True`` to
            ``get_dual_region`` and simplify the result with
            ``simplified_quick`` before returning.
        existing_only: only used when ``optimal_terms`` is False. Collect the
            implicant constraints (``exprs_gei`` / ``exprs_eqi``) whose dual
            value exceeds ``PsiOpts.settings["eps_lp"]`` in absolute value.

    Returns:
        A Region of the collected assumptions. Returns None if some
        constraint could not be proved, or if the auxiliary search failed.

    If the "truth" setting or an independence region (``get_indreg_checked``)
    is present, it is added as an assumption and the check recurses.
    If auxiliary variables are present, they are found with ``check_getaux``
    and substituted, and the check recurses on the auxiliary-free region.
    """
    truth = PsiOpts.settings["truth"]
    if truth is not None:
        with PsiOpts(truth = None):
            return (truth >> self).check_dual(optimal_terms=optimal_terms, force_H=force_H, mul_coeff=mul_coeff, skip_simplify=skip_simplify, existing_only=existing_only)
    indreg = self.get_indreg_checked()
    if indreg is not None:
        with PsiOpts(indreg_enabled = False):
            return (indreg >> self).check_dual(optimal_terms=optimal_terms, force_H=force_H, mul_coeff=mul_coeff, skip_simplify=skip_simplify, existing_only=existing_only)
    if self.aux_present():
        # Resolve auxiliaries first, then check the substituted region.
        rr = self.check_getaux()
        if rr is None:
            return None
        subdict = Comp.substitute_list_to_dict(rr, multi = True)
        cs = self.copy()
        cs = cs.substituted_dict_union(subdict).noaux()
        return cs.check_dual(optimal_terms=optimal_terms, force_H=force_H, mul_coeff=mul_coeff, skip_simplify=skip_simplify, existing_only=existing_only)
    verbose_subset = PsiOpts.settings.get("verbose_subset", False)
    ceps = PsiOpts.settings["eps_lp"]
    # cs = self.imp_flippedonly_noaux()
    cs = self
    r = Region.universe()
    index = IVarIndex()
    cs.record_to(index)
    progs = []
    ctx = None
    if force_H:
        ctx = PsiOpts(lptype = "H")
    else:
        ctx = contextlib.nullcontext()
    with PsiOpts(lp_dual_form = True), ctx:
        progs.append(self.init_prog(index, dual_enabled = True))
        # Equalities are checked as two inequalities (x >= 0 and -x >= 0).
        for x in self.exprs_ge + self.exprs_eq + [-y for y in self.exprs_eq]:
            if not cs.implies_ineq_prog(index, progs, x, ">=", allow_shortcut = False):
                if verbose_subset:
                    print(str(x) + " >= 0 FAIL")
                return None
            if optimal_terms:
                t = progs[0].get_dual_region_terms(x)
                if t is not None:
                    r.iand_norename(t)
                else:
                    r.iand_norename(x >= 0)
            elif existing_only:
                # Keep only the implicant constraints with a nonzero dual value.
                for y in self.exprs_gei:
                    t = progs[0].get_dual_expr(y)
                    if t is not None and abs(t) > ceps:
                        if mul_coeff:
                            r.exprs_ge.append(y * t)
                        else:
                            r.exprs_ge.append(y)
                for y in self.exprs_eqi:
                    t = progs[0].get_dual_expr(y)
                    if t is not None and abs(t) > ceps:
                        if mul_coeff:
                            r.exprs_eq.append(y * t)
                        else:
                            r.exprs_eq.append(y)
            else:
                csum = Expr.zero()
                cdual = progs[0].get_dual_region(mul_coeff = mul_coeff, omit_trivial = not skip_simplify, tosum=csum)
                # print(cdual)
                # print(csum)
                if cdual is not None:
                    r.iand_norename(cdual)
                    # get_dual_region accumulates into csum. If x - csum is
                    # not already nonnegative, the remainder is added as an
                    # explicit assumption.
                    csum = x - csum
                    csum.simplify()
                    if not csum.isnonneg():
                        r.iand_norename(csum >= 0)
                else:
                    r.iand_norename(x >= 0)
            # v, prog = cs.minimum(x).solve_prog("dual")
            # if v is None or v < -ceps or prog is None:
            #     return None
            # cdual = prog.get_dual_region()
            # if cdual is not None:
            #     r &= cdual
    if not skip_simplify:
        r = r.simplified_quick()
    return r
def min_assumption(self, optimal_terms = True, force_H = False, mul_coeff = False, skip_simplify = False, existing_only = False):
    """Get a smaller set of assumptions needed to prove an implication.

    This is ``check_dual`` with different defaults (``mul_coeff`` and
    ``skip_simplify`` default to False).
    WARNING: If not used with parameter "force_H=True", may omit conditional independence and functional dependencies.
    """
    options = dict(
        optimal_terms = optimal_terms,
        force_H = force_H,
        mul_coeff = mul_coeff,
        skip_simplify = skip_simplify,
        existing_only = existing_only,
    )
    return self.check_dual(**options)
def ic_list(self, v):
    """List the partners of ``v`` in two-argument terms of the >= constraints.

    Constraints that are nonpositive are skipped. For each term where
    ``isic2()`` holds, if ``v`` is present in one of its two arguments, a
    copy of the other argument is recorded with the term's coefficient
    (both orders are checked).
    Returns a list of ``(component, coefficient)`` pairs.
    """
    pairs = []
    for x in filter(lambda e: not e.isnonpos(), self.exprs_ge):
        for term, coeff in x.terms:
            if not term.isic2():
                continue
            first, second = term.x[0], term.x[1]
            if first.ispresent(v):
                pairs.append((second.copy(), coeff))
            if second.ispresent(v):
                pairs.append((first.copy(), coeff))
    return pairs
def ic_list_similarity(self, vl, wl):
    """Score the overlap of two ``(component, coefficient)`` lists (see ic_list).

    Each pair of entries whose coefficients do not have opposite signs adds
    2 if the components are equal, or 1 if one component contains the other.
    """
    def pair_score(va, vc, wa, wc):
        if vc * wc < 0:
            return 0
        if va == wa:
            return 2
        if va.super_of(wa) or wa.super_of(va):
            return 1
        return 0
    return sum(pair_score(va, vc, wa, wc) for va, vc in vl for wa, wc in wl)
def get_hc(self):
    """Collect the ``ishc()`` terms of constraints that force their terms to vanish.

    The terms are taken from ">= 0" constraints that are nonpositive, and
    from "== 0" constraints that are nonpositive or nonnegative. Each term
    ``a`` with ``a.ishc()`` contributes ``Term.Hc(a.x[0] - a.z, a.z)`` with
    coefficient 1.0, when ``a.x[0] - a.z`` is nonempty. check_getaux_inplace_gen
    uses the result as functional dependencies.
    Returns an Expr holding these terms.
    """
    result = Expr.zero()

    def add_hc_terms(expr):
        # Append one Hc term per conditional-entropy term of expr.
        for term, _coeff in expr.terms:
            if not term.ishc():
                continue
            dependent = term.x[0] - term.z
            if not dependent.isempty():
                result.terms.append((Term.Hc(dependent, term.z), 1.0))

    for x in self.exprs_ge:
        if x.isnonpos():
            add_hc_terms(x)
    for x in self.exprs_eq:
        if x.isnonpos() or x.isnonneg():
            add_hc_terms(x)
    return result
def get_var_avoid(self, a):
    """Intersect the ``get_var_avoid(a)`` results of all constraints.

    Constraints that return None impose nothing and are ignored. Returns
    None if every constraint returns None; otherwise returns the
    intersection (via ``inter``) of the non-None results.
    """
    result = None
    for constraint in itertools.chain(self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi):
        avoid = constraint.get_var_avoid(a)
        if avoid is None:
            continue
        result = avoid if result is None else result.inter(avoid)
    return result
def get_aux_avoid_list(self):
    """Return ``(aux, avoid)`` pairs for auxiliaries with a known avoid set.

    For each auxiliary in ``self.getaux()``, ``get_var_avoid`` is queried,
    and the pair is kept when the result is not None.
    """
    candidates = ((aux, self.get_var_avoid(aux)) for aux in self.getaux())
    return [(aux, avoid) for aux, avoid in candidates if avoid is not None]
def check_getaux_inplace(self, must_include = None, single_include = None, hint_pair = None, hint_aux = None, hint_aux_avoid = None, max_iter = None, leaveone = None):
    """Return whether implication is true, with auxiliary search result.

    Runs ``check_getaux_inplace_gen``, with ``hint_aux_avoid`` extended by
    ``get_aux_avoid_list()``. Returns the first yielded result that is not a
    signal (``iutil.signal_type(rr) == ""``), or None if there is none.
    """
    base_avoid = [] if hint_aux_avoid is None else hint_aux_avoid
    combined_avoid = base_avoid + self.get_aux_avoid_list()
    results = self.check_getaux_inplace_gen(must_include = must_include,
        single_include = single_include, hint_pair = hint_pair,
        hint_aux = hint_aux, hint_aux_avoid = combined_avoid,
        max_iter = max_iter, leaveone = leaveone)
    return next((rr for rr in results if iutil.signal_type(rr) == ""), None)
def check_getaux_inplace_gen(self, must_include = None, single_include = None, hint_pair = None, hint_aux = None, hint_aux_avoid = None, max_iter = None, leaveone = None):
"""Generator that yields all auxiliary search result"""
if leaveone is None:
leaveone = PsiOpts.settings["auxsearch_leaveone"]
write_pf_enabled = PsiOpts.settings.get("proof_enabled", False)
write_pf_repeat_claim = PsiOpts.settings.get("proof_repeat_implicant", False)
leaveone_add_ineq = PsiOpts.settings["auxsearch_leaveone_add_ineq"]
if self.aux.isempty():
if len(self.exprs_eq) == 0 and len(self.exprs_ge) == 0:
yield []
return
if not leaveone:
oproof = None
proof_stepped_in = False
if write_pf_enabled:
oproof = PsiOpts.get_proof().copy()
pcase = PsiOpts.get_proof().get_case()
if not pcase.isuniverse():
pf = ProofObj.from_region(self if write_pf_repeat_claim else self.consonly(), c = ["Case ", pcase.copy()])
PsiOpts.set_setting(proof_step_in = pf)
pf = PsiOpts.get_proof()
proof_stepped_in = True
desc_more = self.get_meta("pf_note_case")
if desc_more is not None:
pf.insert_step(ProofObj.from_region(pf.claim, c = desc_more), 0)
pf.claim = None
pf.desc += [":"]
cres = self.check_plain(skip_simplify = True)
if proof_stepped_in:
PsiOpts.set_setting(proof_step_out = True)
if cres:
yield []
else:
if oproof is not None:
PsiOpts.set_proof(oproof)
return
verbose = PsiOpts.settings.get("verbose_auxsearch", False)
verbose_step = PsiOpts.settings.get("verbose_auxsearch_step", False)
verbose_result = PsiOpts.settings.get("verbose_auxsearch_result", False)
verbose_cache = PsiOpts.settings.get("verbose_auxsearch_cache", False)
verbose_step_cached = PsiOpts.settings.get("verbose_auxsearch_step_cached", False)
as_generator = True
if must_include is None:
must_include = Comp.empty()
noncircular = PsiOpts.settings["imp_noncircular"]
noncircular_allaux = PsiOpts.settings["imp_noncircular_allaux"]
#noncircular_skipfail = True
forall_multiuse = PsiOpts.settings["forall_multiuse"]
forall_multiuse_numsave = PsiOpts.settings["forall_multiuse_numsave"]
auxsearch_local = PsiOpts.settings["auxsearch_local"]
save_res = auxsearch_local
if max_iter is None:
max_iter = PsiOpts.settings["auxsearch_max_iter"]
lpcost = 100
maxcost = max_iter * lpcost
curcost = [0]
cs = self.copy()
index = IVarIndex()
#cs.record_to(index)
progs = []
#ccomp = must_include + cs.auxi + (index.comprv - cs.auxi - must_include)
ccomp = must_include + cs.auxi + (cs.allcomprv() - cs.aux - cs.auxi - must_include)
index.record(ccomp)
index.record(cs.allcompreal())
auxcomp = cs.aux.copy()
auxcond = Comp.empty()
for a in auxcomp.varlist:
b = Comp([a])
if self.imp_ispresent(b):
auxcond += b
auxcomp = auxcond + auxcomp
n = auxcomp.size()
n_cond = auxcond.size()
m = ccomp.size()
#m_flip = cs.auxi.size()
clist = collections.deque()
clist_hashset = set()
flipflag = 0
auxiclist = [cs.ic_list(a) for a in auxcomp]
cs_flipped = cs.imp_flipped()
ciclist = [cs_flipped.ic_list(a) for a in ccomp]
cvisflag = 0
for i in range(n):
maxj = -1
maxval = 0
for j in range(m):
t = cs.ic_list_similarity(auxiclist[i], ciclist[j])
if t > maxval:
maxval = t
maxj = j
if maxj >= 0:
flipflag |= 1 << (m * i + maxj)
cvisflag |= 1 << maxj
for i in range(n):
if (flipflag >> (m * i)) & ((1 << m) - 1) == 0:
flipflag |= (((1 << (must_include + cs.auxi).size()) - 1) & ~cvisflag) << (m * i)
mustflag = 0
for i in range(n):
#flipflag += ((1 << (must_include + cs.auxi).size()) - 1) << (m * i)
mustflag += ((1 << must_include.size()) - 1) << (m * i)
singleflag = 0
if single_include is not None:
for j in range(m):
if single_include.ispresent(ccomp[j]):
singleflag += 1 << j
#m_flip = 0
# mp2 = (1 << m)
comppair = [-1 for j in range(m)]
auxpair = [-1 for i in range(n)]
compside = [0 for j in range(m)]
auxside = [0 for i in range(n)]
if hint_pair is not None:
for (a, b) in hint_pair:
ai = -1
bi = -1
for i in range(n):
if auxcomp.varlist[i] == a.varlist[0]:
ai = i
auxside[i] = 1
elif auxcomp.varlist[i] == b.varlist[0]:
bi = i
auxside[i] = -1
if ai >= 0 and bi >= 0:
auxpair[ai] = bi
auxpair[bi] = ai
ai = -1
bi = -1
for i in range(m):
if ccomp.varlist[i] == a.varlist[0]:
ai = i
compside[i] = 1
elif ccomp.varlist[i] == b.varlist[0]:
bi = i
compside[i] = -1
if ai >= 0 and bi >= 0:
comppair[ai] = bi
comppair[bi] = ai
setflag = 0
if hint_aux is not None:
for taux, tc in hint_aux:
pair_allowed = taux.get_marker_key("incpair") is not None
tcmask = 0
tcmask_pair = 0
for j in range(m):
if tc.ispresent(ccomp[j]):
tcmask |= 1 << j
if pair_allowed and comppair[j] < 0:
tcmask_pair |= 1 << j
elif pair_allowed and comppair[j] >= 0 and tc.ispresent(ccomp[comppair[j]]):
tcmask_pair |= 1 << j
for i in range(n):
if taux.ispresent(auxcomp[i]):
setflag |= tcmask << (m * i)
elif pair_allowed and auxpair[i] >= 0 and taux.ispresent(auxcomp[auxpair[i]]):
setflag |= tcmask_pair << (m * i)
auxlist = [Comp.empty() for i in range(n)]
auxflag = [0 for i in range(n)]
eqs = []
for x in cs.exprs_ge:
eqs.append((x.copy(), ">="))
for x in cs.exprs_eq:
eqs.append((x.copy(), "=="))
eqvs_range = n * 2 + 1
eqvs = [[] for i in range(eqvs_range)]
eqvsid = [[] for i in range(eqvs_range)]
eqvpresflag = [[] for i in range(eqvs_range)]
eqvleaveok = [[] for i in range(eqvs_range)]
eqvs_emptyid = n * 2
for (x, sg) in eqs:
maxi = -1
mini = 100000000
presflag = 0
for i in range(n):
if x.ispresent(Comp([auxcomp.varlist[i]])):
maxi = max(maxi, i)
mini = min(mini, i)
presflag |= 1 << i
ii = eqvs_emptyid
if maxi >= 0:
if mini == maxi:
ii = maxi * 2
else:
ii = maxi * 2 + 1
eqvsid[ii].append(len(eqvs[ii]))
eqvs[ii].append((x, sg))
eqvpresflag[ii].append(presflag)
eqvleaveok[ii].append(sg == ">=" and not x.isnonpos() and not x.ispresent(auxcond))
eqvsns = [[] for i in range(eqvs_range)]
for i in range(eqvs_range):
for j in range(len(eqvs[i])):
x, sg = eqvs[i][j]
if sg == "==":
eqvsns[i].append(0)
else:
csn = None
for i2 in range(n):
if eqvpresflag[i][j] & (1 << i2) != 0:
tsn = x.get_sign(auxcomp.varlist[i2])
if tsn == 0 or ((csn is not None) and csn != tsn):
csn = 0
break
else:
csn = tsn
eqvsns[i].append(csn if csn is not None else 0)
eqvsncache = [[] for i in range(n * 2 + 1)]
eqvflagcache = [[] for i in range(n * 2 + 1)]
for i in range(eqvs_range):
eqvsncache[i] = [[[], []] for j in range(len(eqvs[i]))]
eqvflagcache[i] = [{} for j in range(len(eqvs[i]))]
auxsetflag = [(setflag >> (m * i)) & ((1 << m) - 1) for i in range(n)]
auxavoidflag = [0 for i in range(n)]
if hint_aux_avoid is not None:
for taux, tc in hint_aux_avoid:
pair_allowed = taux.get_marker_key("incpair") is not None
tcmask = 0
tcmask_pair = 0
for j in range(m):
if tc.ispresent(ccomp[j]):
tcmask |= 1 << j
if pair_allowed and comppair[j] < 0:
tcmask_pair |= 1 << j
elif pair_allowed and comppair[j] >= 0 and tc.ispresent(ccomp[comppair[j]]):
tcmask_pair |= 1 << j
for i in range(n):
if taux.ispresent(auxcomp[i]):
auxavoidflag[i] |= tcmask
elif pair_allowed and auxpair[i] >= 0 and taux.ispresent(auxcomp[auxpair[i]]):
auxavoidflag[i] |= tcmask_pair
if False:
for i in range(n):
t = self.get_var_avoid(auxcomp[i])
if t is not None:
auxavoidflag[i] |= ccomp.get_mask(t)
for i in range(n):
auxavoidflag[i] &= ~auxsetflag[i]
avoidflag = sum(auxavoidflag[i] << (m * i) for i in range(n))
flipflag &= ~avoidflag
disjoint_ids = [-1 for i in range(n)]
symm_ids = [-1 for i in range(n)]
nonsubset_ids = [-1 for i in range(n)]
nonempty_is = [False for i in range(n)]
symm_nonempty_ns = [0] * n
for i in range(n):
if auxcomp.varlist[i].markers is None:
continue
cdict = {v: w for v, w in auxcomp.varlist[i].markers}
disjoint_ids[i] = cdict.get("disjoint", -1)
symm_ids[i] = cdict.get("symm", -1)
nonsubset_ids[i] = cdict.get("nonsubset", -1)
nonempty_is[i] = "nonempty" in cdict
symm_nonempty_ns[i] = cdict.get("symm_nonempty", 0)
for i in range(n):
if symm_nonempty_ns[i] > 0:
nsymm = 0
for i2 in range(i + 1, n):
if symm_ids[i] == symm_ids[i2]:
nsymm += 1
if nsymm < symm_nonempty_ns[i]:
nonempty_is[i] = True
#print("NONSUBSET " + "; ".join(str(auxcomp[i]) for i in range(n) if nonsubset_ids[i] >= 0))
fcns = cs_flipped.get_hc()
fcns_mask = []
for a, c in fcns.terms:
fcns_mask.append((index.get_mask(a.x[0]), index.get_mask(a.z)))
if verbose:
print("========= aux search ========")
print(cs.imp_flippedonly())
print("========= subset of =========")
#print(co)
for i in range(n * 2 + 1):
hs = ""
if i == n * 2:
hs = "NONE"
elif i % 2 == 0:
hs = str(auxcomp.varlist[i // 2])
else:
hs = "<=" + str(auxcomp.varlist[i // 2])
for j in range(len(eqvs[i])):
eqvsnstr = ""
if i < n * 2:
eqvsnstr = " " + ("amb" if eqvsns[i][j] == 0
else "inc" if eqvsns[i][j] == 1 else "dec")
if leaveone and eqvleaveok[i][j]:
eqvsnstr += " " + "leaveok"
print(iutil.strpad(hs, 8, ": " + str(eqvs[i][j][0]) + " " + str(eqvs[i][j][1]) + " 0" + eqvsnstr))
print("========= variables =========")
print(ccomp)
print("========= auxiliary =========")
#print(auxcomp)
print(str(auxcond) + " ; " + str(auxcomp - auxcond))
if len(fcns_mask) > 0:
print("========= functions =========")
for x, z in fcns_mask:
print(str(ccomp.from_mask(x)) + " <- " + str(ccomp.from_mask(z)))
if hint_pair is not None:
print("========= pairing =========")
for i in range(n):
if auxpair[i] > i:
print(str(auxcomp.varlist[i]) + " <-> " + str(auxcomp.varlist[auxpair[i]]))
for i in range(m):
if comppair[i] > i:
print(str(ccomp.varlist[i]) + " <-> " + str(ccomp.varlist[comppair[i]]))
print("========= initial =========")
for i in range(n):
cflag = (flipflag >> (m * i)) & ((1 << m) - 1)
csetflag = auxsetflag[i]
cavoidflag = auxavoidflag[i]
ccor = Comp.empty()
cset = Comp.empty()
cavoid = Comp.empty()
for j in range(m):
if cflag & (1 << j) != 0:
ccor += ccomp[j]
if csetflag & (1 << j) != 0:
cset += ccomp[j]
if cavoidflag & (1 << j) != 0:
cavoid += ccomp[j]
print(str(auxcomp.varlist[i]) + " : " + str(ccor)
+ (" Fix: " + str(cset) if not cset.isempty() else "")
+ (" Avoid: " + str(cavoid) if not cavoid.isempty() else ""))
if self.aux.isempty() and write_pf_enabled and leaveone:
oproof = PsiOpts.get_proof().copy()
pf = ProofObj.from_region(self if write_pf_repeat_claim else self.consonly(), c = "Claim:")
PsiOpts.set_setting(proof_step_in = pf)
pf = PsiOpts.get_proof()
# print(self)
curleave = None
for i in range(n * 2 + 1):
for j in range(len(eqvs[i])):
cs = self.noaux()
cs.exprs_ge = []
cs.exprs_eq = []
if eqvs[i][j][1] == "==":
cs.exprs_eq.append(eqvs[i][j][0])
else:
cs.exprs_ge.append(eqvs[i][j][0])
tres = cs.check_plain(skip_simplify = True)
if not tres:
if not eqvleaveok[i][j] or curleave is not None:
PsiOpts.set_setting(proof_step_out = True)
PsiOpts.set_proof(oproof)
return
else:
curleave = eqvs[i][j][0]
PsiOpts.set_setting(proof_step_out = True)
pcase = PsiOpts.get_proof().get_case()
if curleave is not None or not pcase.isuniverse():
tcase = pcase.copy()
if curleave is not None:
tcase &= (curleave >= 0)
PsiOpts.get_proof().set_case(pcase & (curleave <= 0))
pf.desc = ["Case ", tcase]
desc_more = self.get_meta("pf_note_case")
if desc_more is not None:
# pf.desc += [", ", "\n"] + desc_more
pf.insert_step(ProofObj.from_region(pf.claim, c = desc_more), 0)
pf.claim = None
pf.desc += [":"]
# pf2 = ProofObj.from_region(None, c = ["Remaining case: ", curleave <= 0])
# PsiOpts.set_setting(proof_add = pf2)
# print("WRITE PF LEAVEONE " + str(curleave))
if curleave is None:
yield []
else:
yield ("leaveone", [], -curleave)
return
if len(eqs) == 0:
#print(bin(setflag))
#print(bin(avoidflag))
#print("m=" + str(m) + " n=" + str(n))
#print("ccomp=" + str(ccomp) + " auxcomp=" + str(auxcomp))
#for a in auxcomp.varlist:
# print(str(a) + " " + str(a.markers))
mleft = [j for j in range(m * n) if setflag & (1 << j) == 0 and avoidflag & (1 << j) == 0]
def check_nocond_recur(i, size, allflag):
#print(str(i) + " " + str(size) + " " + bin(allflag))
if i == n:
if mustflag != 0 and mustflag & allflag == 0:
return
rr = [(auxcomp[i2].copy(), auxlist[i2].copy()) for i2 in range(n)]
yield rr
return
mlefti = [j - m * i for j in mleft if j >= m * i and j < m * (i + 1)]
symm_break = -1
if symm_ids[i] >= 0:
for i2 in range(i):
if symm_ids[i] == symm_ids[i2]:
symm_break = max(symm_break, auxflag[i2])
if disjoint_ids[i] >= 0:
for i2 in range(i):
if disjoint_ids[i] == disjoint_ids[i2]:
mlefti = [j for j in mlefti if auxflag[i2] & (1 << j) == 0]
sizelb = max(size - sum(j >= m * (i + 1) for j in mleft), 0)
sizeub = min(min(len(mlefti), m), size)
if sizelb > sizeub:
return
for tsize in range(sizelb, sizeub + 1):
for comb in itertools.combinations(list(reversed(mlefti)), tsize):
curcost[0] += 1
if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost):
return
auxflag[i] = sum(1 << j for j in comb) | auxsetflag[i]
if auxflag[i] < symm_break:
break
if nonempty_is[i] and auxflag[i] == 0:
continue
if any(auxflag[i] | z == auxflag[i] and auxflag[i] & x != 0 for x, z in fcns_mask):
continue
if nonsubset_ids[i] >= 0:
tbad = False
for i2 in range(i):
if nonsubset_ids[i] == nonsubset_ids[i2]:
if auxflag[i] | auxflag[i2] == auxflag[i] or auxflag[i] | auxflag[i2] == auxflag[i2]:
tbad = True
break
if tbad:
continue
auxlist[i] = ccomp.from_mask(auxflag[i])
#print(str(i) + " " + str(m) + " " + str(ccomp) + " " + bin(auxflag[i]) + " " + str(auxlist[i]))
for rr in check_nocond_recur(i + 1, size - len(comb), allflag | (auxflag[i] << (m * i))):
yield rr
#print("START")
for tsize in range(len(mleft) + 1):
for rr in check_nocond_recur(0, tsize, 0):
#print("; ".join(str(v) + ":" + str(w) for v, w in rr))
yield rr
#print("END")
return
auxcache = [{} for i in range(n)]
if verbose:
print("========= progress: =========")
leaveone_static = None
# *********** Check conditions that does not depend on aux ***********
if n_cond == 0:
for (x, sg) in eqvs[eqvs_emptyid]:
if not cs.implies_ineq_prog(index, progs, x, sg, save_res = save_res):
if leaveone and sg == ">=":
if verbose_step:
print(" F LO " + str(x) + " " + sg + " 0")
leaveone = False
leaveone_static = -x
else:
if verbose_step:
print(" F " + str(x) + " " + sg + " 0")
return None
allflagcache = [{} for i in range(n + 1)]
cursizepass = 0
cs = Region.universe()
cs_added = Region.universe()
#self.imp_only_copy_to(cs_added)
condflagadded = {}
condflagadded_true = collections.deque()
flagcache = [set() for i in range(n + 1)]
maxprogress = [-1, 0, -1]
flipflaglen = 0
numfail = [0]
numclear = [0]
auxavoidflag_orig = list(auxavoidflag)
def clear_cache(mini, maxi):
if verbose_cache:
print("========= cache clear: " + str(mini) + " - " + str(maxi) + " =========")
progs[:] = []
auxcache[mini:maxi] = [{} for i in range(mini, maxi)]
for a in flagcache:
a.clear()
for i in range(mini, maxi):
auxavoidflag[i] = auxavoidflag_orig[i]
for i in range(mini * 2, eqvs_range):
for j in range(len(eqvs[i])):
if eqvpresflag[i][j] & ((1 << maxi) - (1 << mini)) != 0:
eqvsncache[i][j] = [[], []]
eqvflagcache[i][j] = {}
# eqvsncache[i] = [[[], []] for j in range(len(eqvs[i]))]
# eqvflagcache[i] = [{} for j in range(len(eqvs[i]))]
def build_region(i, allflag, allownew, cflag = 0, add_ineq = None):
numclear[0] += 1
cs_added_changed = False
prev_csstr = cs.tostring(tosort = True)
if add_ineq is not None:
prev_csaddedstr = cs_added.tostring(tosort = True)
cs_added.exprs_gei.append(add_ineq)
cs_added.simplify_quick(zero_group = 2)
#cs_added.split()
if prev_csaddedstr != cs_added.tostring(tosort = True):
cs_added_changed = True
clist.appendleft(cflag)
while len(clist) > forall_multiuse_numsave:
clist.pop()
if verbose_step:
print("========= leave one added =========")
print(cs_added.imp_flipped())
print("==================")
cs_added.imp_only_copy_to(cs)
elif i >= n_cond and (forall_multiuse or leaveone):
csnew = Region.universe()
if not (allflag in condflagadded):
self.imp_only_copy_to(csnew)
for i3 in range(i, n_cond):
csnew.remove_present(Comp([auxcomp.varlist[i3]]))
for i3 in range(i):
csnew.substitute(Comp([auxcomp.varlist[i3]]), auxlist[i3])
if allownew:
condflagadded[allflag] = True
prev_csaddedstr = cs_added.tostring(tosort = True)
cs_added.iand_norename(csnew)
cs_added.simplify_quick(zero_group = 2)
#cs_added.split()
if prev_csaddedstr != cs_added.tostring(tosort = True):
cs_added_changed = True
condflagadded_true.appendleft(allflag)
while len(condflagadded_true) > forall_multiuse_numsave:
condflagadded_true.pop()
if verbose_step:
print("========= forall added =========")
print(cs_added.imp_flipped())
print("==================")
cs_added.imp_only_copy_to(cs)
if not allownew:
cs.iand_norename(csnew)
else:
self.imp_only_copy_to(cs)
for i3 in range(i, n_cond):
cs.remove_present(Comp([auxcomp.varlist[i3]]))
for i3 in range(i):
cs.substitute(Comp([auxcomp.varlist[i3]]), auxlist[i3])
if cs_added_changed or prev_csstr != cs.tostring(tosort = True):
if verbose_cache:
print("========= cleared =========")
print(prev_csstr)
print("========= to =========")
print(cs.tostring(tosort = True))
maxi = n
if forall_multiuse and not cs_added_changed:
maxi = n_cond
if noncircular:
if noncircular_allaux:
clear_cache(n_cond, maxi)
else:
clear_cache(min(1, n_cond), maxi)
else:
clear_cache(0, maxi)
else:
if verbose_cache:
print("========= not cleared =========")
print(prev_csstr)
#build_region(0, 0, True)
build_region(0, 0, leaveone)
def is_marker_sat(i):
if nonempty_is[i] and auxflag[i] == 0:
return False
if any(auxflag[i] | z == auxflag[i] and auxflag[i] & x != 0 for x, z in fcns_mask):
return False
if nonsubset_ids[i] >= 0:
for i2 in range(i):
if nonsubset_ids[i] == nonsubset_ids[i2]:
if auxflag[i] | auxflag[i2] == auxflag[i] or auxflag[i] | auxflag[i2] == auxflag[i2]:
return False
if symm_ids[i] >= 0:
for i2 in range(i):
if symm_ids[i] == symm_ids[i2]:
if auxflag[i] < auxflag[i2]:
return False
if disjoint_ids[i] >= 0:
tbad = False
for i2 in range(i):
if disjoint_ids[i] == disjoint_ids[i2]:
if auxflag[i] & auxflag[i2] != 0:
return False
return True
# *********** Sandwich procedure ***********
dosandwich = PsiOpts.settings["auxsearch_sandwich"]
dosandwich_inc = PsiOpts.settings["auxsearch_sandwich_inc"]
if dosandwich:
if verbose:
print("========== sandwich =========")
for i in range(n * 2 + 1):
if i == eqvs_emptyid:
continue
for j in range(len(eqvs[i])):
if leaveone and eqvleaveok[i][j]:
continue
presflag = eqvpresflag[i][j]
for sg in ([1, -1] if eqvs[i][j][1] == "==" else [1]):
x = eqvs[i][j][0] * sg
sns = [0] * n
bad = False
for i2 in range(n):
if presflag & (1 << i2):
sns[i2] = x.get_sign(auxcomp.varlist[i2])
if sns[i2] == 0:
bad = True
break
if not dosandwich_inc and sns[i2] > 0:
bad = True
break
if bad:
continue
for ix in range(-1, n):
if sns[ix] == 0:
continue
xt = x.copy()
for i2 in range(n):
if i2 == ix:
continue
cmask = 0
if sns[i2] < 0:
cmask = (setflag >> (m * i2)) & ((1 << m) - 1)
else:
cmask = ~(avoidflag >> (m * i2)) & ((1 << m) - 1)
# print(avoidflag)
# print(cmask)
xt.substitute(Comp([auxcomp.varlist[i2]]), ccomp.from_mask(cmask))
if ix == -1:
tres = cs.implies_ineq_prog(index, progs, xt, ">=", save_res = save_res, saved = "both")
if verbose_step:
print(str(x) + " : " + str(xt) + " >= 0 " + str(tres))
if not tres:
return None
continue
cmask0 = (setflag >> (m * ix)) & ((1 << m) - 1)
cmask1 = ~(avoidflag >> (m * ix)) & ((1 << m) - 1)
cmaskc = cmask1 & ~cmask0
for k in range(m):
if cmaskc & (1 << k):
cmask = 0
if sns[ix] < 0:
cmask = cmask0 | (1 << k)
else:
cmask = cmask1 & ~(1 << k)
xt2 = xt.copy()
xt2.substitute(Comp([auxcomp.varlist[ix]]), ccomp.from_mask(cmask))
tres = cs.implies_ineq_prog(index, progs, xt2, ">=", save_res = save_res, saved = "both")
if verbose_step:
print(str(x) + " : " + str(xt2) + " >= 0 " + str(tres))
if not tres:
if sns[ix] < 0:
avoidflag |= 1 << (m * ix + k)
auxavoidflag[ix] |= 1 << k
else:
setflag |= 1 << (m * ix + k)
auxsetflag[ix] |= 1 << k
flipflag &= ~avoidflag
auxavoidflag_orig = list(auxavoidflag)
if verbose:
print("======= after sandwich ======")
for i in range(n):
cflag = (flipflag >> (m * i)) & ((1 << m) - 1)
csetflag = auxsetflag[i]
cavoidflag = auxavoidflag[i]
ccor = Comp.empty()
cset = Comp.empty()
cavoid = Comp.empty()
for j in range(m):
if cflag & (1 << j) != 0:
ccor += ccomp[j]
if csetflag & (1 << j) != 0:
cset += ccomp[j]
if cavoidflag & (1 << j) != 0:
cavoid += ccomp[j]
print(str(auxcomp.varlist[i]) + " : " + str(ccor)
+ (" Fix: " + str(cset) if not cset.isempty() else "")
+ (" Avoid: " + str(cavoid) if not cavoid.isempty() else ""))
print("=============================")
def check_local(i0, allflag, leave_id = None):
cflag = ((flipflag | setflag) >> (m * i0)) << (m * i0)
mleft = [j for j in range(m * i0, m * n) if setflag & (1 << j) == 0]
#mleft = mleft[::-1]
while True:
cleave_id = leave_id
for i in range(i0, n):
auxflag[i] = (cflag >> (m * i)) & ((1 << m) - 1)
auxlist[i] = ccomp.from_mask(auxflag[i])
cres = True
bad = False
for i in range(i0, n):
if not is_marker_sat(i):
bad = True
break
if bad:
cres = False
if cres:
for i2 in range(i0 * 2, n * 2):
i2r = -1
isone = False
if i2 != eqvs_emptyid:
i2r = i2 // 2
isone = (i2 % 2 == 0)
auxflagi = auxflag[i2r]
if isone and (auxflagi in auxcache[i2r]):
cres = auxcache[i2r][auxflagi]
else:
for ieqid in range(len(eqvs[i2])):
ieq = eqvsid[i2][ieqid]
if cleave_id == (i2, ieq):
continue
(x, sg) = eqvs[i2][ieq]
x2 = x.copy()
if isone:
x2.substitute(Comp([auxcomp.varlist[i2r]]), auxlist[i2r])
else:
for i3 in range(i2r + 1):
x2.substitute(Comp([auxcomp.varlist[i3]]), auxlist[i3])
curcost[0] += lpcost
if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost):
return False
tres = cs.implies_ineq_prog(index, progs, x2, sg, save_res = save_res)
#print(iutil.strpad("; ".join([str(auxlist[i3]) for i3 in range(i)]),
# 26, " LO#" + str(numfail[0]), 8, " " + str(x) + " " + sg + " 0 " + str(tres)))
if not tres:
if isone and eqvsns[i2][ieq] < 0 and (not leaveone or not eqvleaveok[i2][ieq]) and iutil.bitcount(auxflagi & ~auxsetflag[i]) == 1:
auxavoidflag[i2r] |= auxflagi & ~auxsetflag[i]
if verbose_step:
print(iutil.strpad(" L AVOID " + str(auxcomp.varlist[i2r]),
12, " : " + str(ccomp.from_mask(auxavoidflag[i2r]))))
if leaveone and cleave_id is None and eqvleaveok[i2][ieq]:
cleave_id = (i2, ieq)
if verbose_step:
print(iutil.strpad("; ".join([str(auxlist[i3]) for i3 in range(n)]),
26, " LSET#" + str(numfail[0]), 12, " " + str(x) + " " + sg + " 0"))
else:
if verbose_step:
numfail[0] += 1
print(iutil.strpad("; ".join([str(auxlist[i3]) for i3 in range(n)]),
26, " LO#" + str(numfail[0]), 12, " " + str(x) + " " + sg + " 0"))
eqvsid[i2].pop(ieqid)
eqvsid[i2].insert(0, ieq)
flagcache[i2r + 1].add(cflag & ((1 << ((i2r + 1) * m)) - 1))
#print("FCA " + str(i2r + 1) + " " + bin(cflag & ((1 << ((i2r + 1) * m)) - 1)))
cres = False
break
if isone and not leaveone:
auxcache[i2r][auxflagi] = cres
if not cres:
break
if cres:
if as_generator:
if leaveone and cleave_id is not None:
x2 = -eqvs[cleave_id[0]][cleave_id[1]][0]
for i3 in range(n):
x2.substitute(Comp([auxcomp.varlist[i3]]), auxlist[i3])
#print("BUILD " + str(x2))
x2.simplify_quick()
if not x2.isnonneg():
if leaveone_add_ineq:
build_region(n_cond, allflag, False, cflag = cflag, add_ineq = x2)
yield ("leaveone", x2.copy())
else:
yield allflag | cflag
else:
return True
flagcache[n].add(cflag)
def check_local_recur(i, kflag, sizelb, sizeub, kleave_id):
csizelb = max(sizelb - sum(j >= m * (i + 1) for j in mleft), 0)
csizeub = sizeub
#print(str(i) + " " + str(csizelb) + " " + str(csizeub))
for tsize in range(csizelb, csizeub + 1):
mlefti = [j - m * i for j in mleft if j >= m * i and j < m * (i + 1)]
mlefti = [j for j in mlefti if auxavoidflag[i] & (1 << j) == 0]
cflagi = ((cflag >> (m * i)) & ((1 << m) - 1))
mleftimustflag = cflagi & auxavoidflag[i]
ttsize = tsize - iutil.bitcount(mleftimustflag)
if ttsize > len(mlefti):
break
if ttsize < 0:
continue
for comb in itertools.combinations(mlefti, ttsize):
curcost[0] += 1
if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost):
return False
auxflag[i] = sum(1 << j for j in comb) ^ cflagi
auxflag[i] &= ~auxavoidflag[i]
if not auxcache[i].get(auxflag[i], True):
continue
if not is_marker_sat(i):
continue
auxlist[i] = ccomp.from_mask(auxflag[i])
ckflag = kflag | (auxflag[i] << (m * i))
ckleave_id = kleave_id
if i == n - 1 and mustflag != 0 and mustflag & (allflag | ckflag) == 0:
continue
if ckflag in flagcache[i + 1]:
continue
cres = True
for i2 in range(i * 2, (i + 1) * 2):
i2r = -1
isone = False
if i2 != eqvs_emptyid:
i2r = i2 // 2
isone = (i2 % 2 == 0)
auxflagi = auxflag[i2r]
for ieqid in range(len(eqvs[i2])):
ieq = eqvsid[i2][ieqid]
(x, sg) = eqvs[i2][ieq]
x2 = x.copy()
if isone:
x2.substitute(Comp([auxcomp.varlist[i2r]]), auxlist[i2r])
else:
for i3 in range(i2r + 1):
x2.substitute(Comp([auxcomp.varlist[i3]]), auxlist[i3])
curcost[0] += lpcost
if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost):
return False
tres = cs.implies_ineq_prog(index, progs, x2, sg, save_res = save_res, saved = True)
if not tres:
if isone and eqvsns[i2][ieq] < 0 and (not leaveone or not eqvleaveok[i2][ieq]) and iutil.bitcount(auxflagi & ~auxsetflag[i]) == 1:
auxavoidflag[i2r] |= auxflagi & ~auxsetflag[i]
if verbose_step:
print(iutil.strpad(" T AVOID " + str(auxcomp.varlist[i2r]),
12, " : " + str(ccomp.from_mask(auxavoidflag[i2r]))))
if leaveone and ckleave_id is None and eqvleaveok[i2][ieq]:
#if verbose_step and verbose_step_cached:
# print(iutil.strpad("; ".join([str(auxlist[i3]) for i3 in range(i2r + 1)]),
# 26, " TSET=" + str(tsize) + ",#" + str(numfail[0]), 12, " " + str(x) + " " + sg + " 0"))
ckleave_id = (i2, ieq)
else:
if verbose_step and verbose_step_cached:
numfail[0] += 1
print(iutil.strpad("; ".join([str(auxlist[i3]) for i3 in range(i2r + 1)]),
26, " TO=" + str(tsize) + ",#" + str(numfail[0]), 12, " " + str(x) + " " + sg + " 0"))
eqvsid[i2].pop(ieqid)
eqvsid[i2].insert(0, ieq)
flagcache[i2r + 1].add(ckflag)
#print("FCB " + str(i2r + 1) + " " + bin(ckflag))
cres = False
break
if isone and not cres and not leaveone:
auxcache[i2r][auxflagi] = cres
if not cres:
break
if not cres:
continue
if i == n - 1:
return True
if check_local_recur(i + 1, ckflag, sizelb - len(comb), sizeub - len(comb), ckleave_id):
return True
if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost):
return False
return False
sizeseq = [0, 1, 2, 4]
sizeseq = [s for s in sizeseq if s < len(mleft)] + [len(mleft)]
found = False
for si in range(1, len(sizeseq)):
if si == 0:
tcflag = cflag
cflag = 0
if check_local_recur(i0, 0, 1, 1, None):
found = True
cflag = sum(auxflag[i] << (m * i) for i in range(i0, n))
break
cflag = tcflag
else:
if check_local_recur(i0, 0, sizeseq[si - 1] + 1, sizeseq[si], None):
found = True
cflag = sum(auxflag[i] << (m * i) for i in range(i0, n))
break
if not found:
return False
def check_recur(i, size, stepsize, allflag):
if i == n and mustflag != 0 and mustflag & allflag == 0:
return False
if allflag in allflagcache[i]:
return False
cprogress = i * (n * 2 + 2)
if cprogress > maxprogress[0]:
maxprogress[0] = cprogress
maxprogress[1] = allflag | (flipflag & ~((1 << (m * i)) - 1))
maxprogress[2] = i
if not noncircular and i == n_cond:
if verbose_cache:
print("========= cache clear: circ, suff # " + str(numclear[0]) + " =========")
build_region(n_cond, allflag, False)
i2lb = 0
i2ub = 0
if noncircular:
if i > 0:
i2lb = (i - 1) * 2
i2ub = i * 2
else:
if i >= n_cond:
i2lb = 0
if i > n_cond:
i2lb = (i - 1) * 2
i2ub = i * 2
for i2 in range(i2lb, i2ub):
cres = True
i2r = -1
isone = False
if i2 != eqvs_emptyid:
i2r = i2 // 2
isone = (i2 % 2 == 0)
auxflagi = auxflag[i2r]
if isone and (auxflagi in auxcache[i2r]):
cres = auxcache[i2r][auxflagi]
else:
for ieqid in range(len(eqvs[i2])):
ieq = eqvsid[i2][ieqid]
(x, sg) = eqvs[i2][ieq]
x2 = x.copy()
if isone:
x2.substitute(Comp([auxcomp.varlist[i2r]]), auxlist[i2r])
else:
for i3 in range(i2r + 1):
x2.substitute(Comp([auxcomp.varlist[i3]]), auxlist[i3])
auxflagpres = 0
auxflagprescn = 0
for i3 in range(i2r + 1):
if eqvpresflag[i2][ieq] & (1 << i3) != 0:
auxflagpres |= auxflag[i3] << (m * auxflagprescn)
auxflagprescn += 1
tres = None
computed = False
eqvsn = eqvsns[i2][ieq]
if eqvsn == 0:
tres = eqvflagcache[i2][ieq].get(auxflagpres, None)
else:
eqvsn = (eqvsn + 1) // 2
for f in eqvsncache[i2][ieq][eqvsn]:
if auxflagpres == f | auxflagpres:
tres = (eqvsn == 1)
break
if tres is None:
for f in eqvsncache[i2][ieq][1 - eqvsn]:
if f == f | auxflagpres:
tres = (eqvsn == 0)
break
if tres is None:
curcost[0] += lpcost
if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost):
return False
tres = cs.implies_ineq_prog(index, progs, x2, sg, save_res = save_res)
computed = True
eqvsn = eqvsns[i2][ieq]
if eqvsn == 0:
eqvflagcache[i2][ieq][auxflagpres] = tres
else:
eqvsncache[i2][ieq][1 if tres else 0].append(auxflagpres)
if not tres:
if verbose_step and (verbose_step_cached or computed):
numfail[0] += 1
print(iutil.strpad("; ".join([str(auxlist[i3]) for i3 in range(i)]),
26, " S=" + str(cursizepass) + ",T=" + str(stepsize)
+ ",L=" + str(flipflaglen) + ",#" + str(numfail[0]), 18, " " + str(x) + " " + sg + " 0"
+ ("" if computed else " (Ca)")))
eqvsid[i2].pop(ieqid)
eqvsid[i2].insert(0, ieq)
cres = False
break
if isone:
auxcache[i2r][auxflagi] = cres
if not cres:
allflagcache[i][allflag] = True
return False
cprogress = i * (n * 2 + 2) + i2 + 1
if cprogress > maxprogress[0]:
maxprogress[0] = cprogress
maxprogress[1] = allflag | (flipflag & ~((1 << (m * i)) - 1))
maxprogress[2] = i
if i == n_cond and n_cond > 0:
if noncircular:
if verbose_cache:
print("========= cache clear: nonc, checkempty # " + str(numclear[0]) + " =========")
build_region(i, allflag, True)
for ieqid in range(len(eqvs[eqvs_emptyid])):
ieq = eqvsid[eqvs_emptyid][ieqid]
(x, sg) = eqvs[eqvs_emptyid][ieq]
curcost[0] += lpcost
if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost):
return False
if not cs.implies_ineq_prog(index, progs, x, sg, save_res = save_res):
if verbose_step:
numfail[0] += 1
print(iutil.strpad("; ".join([str(auxlist[i3]) for i3 in range(i)]),
26, " S=" + str(cursizepass) + ",T=" + str(stepsize)
+ ",L=" + str(flipflaglen) + ",#" + str(numfail[0]), 18, " " + str(x) + " " + sg + " 0"))
allflagcache[i][allflag] = True
eqvsid[eqvs_emptyid].pop(ieqid)
eqvsid[eqvs_emptyid].insert(0, ieq)
return False
if i == n:
if as_generator:
yield allflag
return
else:
return True
if i == n_cond and auxsearch_local:
if as_generator:
for rr in check_local(i, allflag):
yield rr
return
else:
return check_local(i, allflag)
cflipflag = (flipflag >> (m * i)) & ((1 << m) - 1)
if i >= flipflaglen - 1:
i2 = auxpair[i]
if i2 >= 0 and i2 < i:
cflipflag = auxflag[i2]
for j in range(m):
j2 = comppair[j]
if j2 >= 0:
if (auxflag[i2] & (1 << j) != 0) and (auxflag[i2] & (1 << j2) == 0):
cflipflag &= ~(1 << j)
csetflag = (setflag >> (m * i)) & ((1 << m) - 1)
mleft = [j for j in range(m) if csetflag & (1 << j) == 0]
#sizelb = max(0, size - m * (n - i - 1))
sizelb = 0
sizeub = min(min(size, len(mleft)), stepsize)
pnumclear = -1
for tsize in range(sizelb, sizeub + 1):
for comb in itertools.combinations(mleft, tsize):
curcost[0] += 1
if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost):
return False
#print(tsize, comb)
tflag = 0
for j in comb:
tflag += (1 << j)
auxlist[i] = Comp.empty()
auxflag[i] = 0
for j in range(m):
if (csetflag & (1 << j) != 0) or (cflipflag & (1 << j) != 0) != (tflag & (1 << j) != 0):
auxflag[i] += (1 << j)
auxlist[i].varlist.append(ccomp.varlist[j])
if (auxflag[i] & singleflag) != 0 and iutil.bitcount(auxflag[i]) > 1:
continue
if i < n_cond:
pass
else:
if (auxflag[i] in auxcache[i]) and not auxcache[i]:
continue
if noncircular and i < n_cond and numclear[0] != pnumclear:
if verbose_cache:
print("========= cache clear: nonc, inc # " + str(numclear[0]) + " =========")
if noncircular_allaux:
build_region(0, 0, False)
else:
build_region(i, allflag, False)
pnumclear = numclear[0]
recur = check_recur(i+1, size - tsize, stepsize, allflag + (auxflag[i] << (m * i)))
if as_generator:
for rr in recur:
yield rr
else:
if recur:
return True
if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost):
return False
return False
res_hashset = set()
maxsize = m * n - iutil.bitcount(setflag)
size = 0
stepsize = 0
while True:
cursizepass = size
prevprogress = maxprogress[0]
recur = check_recur(0, size, stepsize, 0)
if not as_generator:
recur = [True] if recur else []
for rr in recur:
if verbose or verbose_result:
print("========= success cost " + str(curcost[0]) + "/" + str(maxcost) + " =========")
#print("========= final region =========")
#print(cs.imp_flipped())
print("========== aux =========")
for i in range(n):
print(iutil.strpad(str(auxcomp.varlist[i]), 6, ": " + str(auxlist[i])))
namelist = [auxcomp.varlist[i].name for i in range(n)]
res = []
for i in range(n):
i2 = namelist.index(self.aux.varlist[i].name)
cval = auxlist[i2].copy()
if forall_multiuse and i2 < n_cond and len(condflagadded_true) > 0:
cval = [ccomp.from_mask(x >> (m * i2)).copy() for x in condflagadded_true]
if len(cval) == 1:
cval = cval[0]
if i2 >= n_cond and len(clist) > 0:
cval = [cval] + [ccomp.from_mask(x >> (m * i2)).copy() for x in clist]
if len(cval) == 1:
cval = cval[0]
res.append((Comp([self.aux.varlist[i].copy()]), cval))
if as_generator:
res_hash = hash(iutil.list_tostr_std(res))
if not (res_hash in res_hashset):
if iutil.signal_type(rr) == "leaveone":
yield ("leaveone", res, rr[1])
elif leaveone_static is not None:
yield ("leaveone", res, leaveone_static.copy())
else:
yield res
res_hashset.add(res_hash)
else:
return res
if size >= maxsize:
break
if n_cond == 0 and auxsearch_local:
break
flipflag = maxprogress[1]
flipflaglen = maxprogress[2]
if prevprogress != maxprogress[0]:
size = 1
else:
if size < 2:
size += 1
else:
size *= 2
if size >= maxsize:
size = maxsize
stepsize = m
else:
clen = max(flipflaglen, 1)
stepsize = (size + clen - 1) // clen
if PsiOpts.is_timer_ended() or (maxcost > 0 and curcost[0] >= maxcost):
yield ("max_iter_reached", )
return None
def add_sfrl_imp(self, x, y, gap = None, noaux = True, name = None):
    """Append SFRL-style constraints for a fresh random variable W to ``exprs_gei``.

    Entries of ``exprs_gei`` are expressions constrained to be >= 0, so the
    constraints added are:
      I(X;W) = 0, H(Y|X,W) = 0,
      I(W; others | X,Y) = 0   (others = existing non-aux variables minus X, Y),
      I(X;W|Y) <= gap          (only when ``gap`` is given).
    (Presumably the strong functional representation lemma -- confirm.)

    ``gap`` may be an Expr or a plain constant (wrapped with ``Expr.const``).
    ``name`` defaults to a collision-free "Y%X"-style name. If ``noaux`` is
    False, W is also added to ``self.auxi``. Returns the new variable W.
    """
    # Take the snapshot of existing variables before W appears anywhere.
    base_rvs = self.allcomprv() - self.aux
    if name is None:
        label = y.tostring(add_bracket = True) + "%" + x.tostring(add_bracket = True)
        name = self.name_avoid(label)
    w = Comp.rv(name)
    implicant = self.exprs_gei
    implicant.append(-Expr.I(x, w))
    implicant.append(-Expr.Hc(y, x + w))
    rest = base_rvs - x - y
    if not rest.isempty():
        implicant.append(-Expr.Ic(w, rest, x + y))
    if gap is not None:
        gap_expr = gap if isinstance(gap, Expr) else Expr.const(gap)
        implicant.append(gap_expr.copy() - Expr.Ic(x, w, y))
    if not noaux:
        self.auxi += w
    return w
@fcn_list_to_list
def add_sfrl(self, x, y, gap = None, noaux = True, name = None):
    """Apply ``add_sfrl_imp`` to this region with implicant and consequent swapped.

    The region is flipped with ``imp_flip``, the SFRL constraints are added,
    and the region is flipped back. Presumably this puts the constraints on
    the non-implicant side; confirm against ``imp_flip``. Returns the new
    variable created by ``add_sfrl_imp``.

    The second flip runs in a ``finally`` clause. If ``add_sfrl_imp``
    raises, the region is still restored to its original orientation
    instead of being left flipped.
    """
    self.imp_flip()
    try:
        return self.add_sfrl_imp(x, y, gap, noaux, name)
    finally:
        self.imp_flip()
def add_esfrl_imp(self, x, y, gap = None, noaux = True):
    """Append extended-SFRL constraints for fresh variables W and K to ``exprs_gei``.

    Entries of ``exprs_gei`` are constrained to be >= 0, so this adds:
      I(X;W) = 0, H(K|X,W) = 0, H(Y|W,K) = 0,
      I(W,K; others | X,Y) = 0   (others = non-aux variables minus X, Y),
      H(K) <= I(X;Y) + gap       (only when ``gap`` is given).

    Trivial cases add nothing:
      - if X is a superset of Y, returns (empty, copy of Y);
      - if X is empty, returns (copy of Y, empty).

    If ``noaux`` is False, W and K are added to ``self.auxi``.
    Returns the pair (W, K).
    """
    if x.super_of(y):
        return Comp.empty(), y.copy()
    if x.isempty():
        return y.copy(), Comp.empty()
    base_rvs = self.allcomprv() - self.aux
    label = y.tostring(add_bracket = True) + "%" + x.tostring(add_bracket = True)
    w = Comp.rv(self.name_avoid(label))
    k = Comp.rv(self.name_avoid(label + "_K"))
    implicant = self.exprs_gei
    implicant.append(-Expr.I(x, w))
    implicant.append(-Expr.Hc(k, x + w))
    implicant.append(-Expr.Hc(y, w + k))
    rest = base_rvs - x - y
    if not rest.isempty():
        implicant.append(-Expr.Ic(w + k, rest, x + y))
    if gap is not None:
        gap_expr = gap if isinstance(gap, Expr) else Expr.const(gap)
        implicant.append(gap_expr.copy() + Expr.I(x, y) - Expr.H(k))
    if not noaux:
        self.auxi += w + k
    return w, k
def add_esfrl(self, x, y, gap = None, noaux = True):
    """Apply ``add_esfrl_imp`` to this region with implicant and consequent swapped.

    The region is flipped with ``imp_flip``, the constraints are added, and
    the region is flipped back. Returns the ``(W, K)`` pair from
    ``add_esfrl_imp``.

    The second flip runs in a ``finally`` clause, so the region is restored
    even if ``add_esfrl_imp`` raises.
    """
    self.imp_flip()
    try:
        return self.add_esfrl_imp(x, y, gap, noaux)
    finally:
        self.imp_flip()
def check_getaux_sfrl(self, sfrl_level = None, sfrl_minsize = 0, sfrl_maxsize = None, sfrl_gap = None, hint_pair = None, hint_aux = None):
    """Retry the auxiliary search after adding SFRL variables to the region.

    Recursively chooses up to ``sfrl_maxsize`` pairs (X, Y). Each X and Y is
    a non-empty, disjoint subset of the region's non-aux random variables,
    with ``auxi`` placed first. For every partial choice of at least
    ``sfrl_minsize`` pairs, a copy of the region gets one ``add_sfrl_imp``
    variable per pair, and ``check_getaux_inplace`` is run with those
    variables as ``must_include``. When ``sfrl_level`` is below
    ``PsiOpts.SFRL_LEVEL_MULTIPLE``, they are also passed as
    ``single_include``.

    ``sfrl_gap``: "zero" uses a zero gap, "" adds no gap constraint, and any
    other value is passed to ``Expr.real``. Parameters left as None are read
    from ``PsiOpts.settings``.

    Returns the first non-None result of ``check_getaux_inplace``, or None.
    """
    verbose_sfrl = PsiOpts.settings.get("verbose_sfrl", False)
    if sfrl_level is None:
        sfrl_level = PsiOpts.settings["sfrl_level"]
    if sfrl_maxsize is None:
        sfrl_maxsize = PsiOpts.settings["sfrl_maxsize"]
    if sfrl_gap is None:
        sfrl_gap = PsiOpts.settings["sfrl_gap"]
    gap = None
    if sfrl_gap == "zero":
        gap = Expr.zero()
    elif sfrl_gap != "":
        gap = Expr.real(sfrl_gap)
    enable_multiple = (sfrl_level >= PsiOpts.SFRL_LEVEL_MULTIPLE)
    n = sfrl_maxsize
    ccomp = self.auxi + (self.allcomprv() - self.aux - self.auxi)
    m = ccomp.size()
    # sfrlcomp[i] = [xmask, ymask]: bitmasks over ccomp for the i-th SFRL pair.
    sfrlcomp = [[0, 0] for _ in range(n)]

    def check_getaux_sfrl_recur(i):
        # Pairs 0..i-1 in sfrlcomp are fixed; try the search with them first.
        if i >= sfrl_minsize:
            cs = self.copy()
            csfrl = Comp.empty()
            for i2 in range(i):
                sfrlx = sum([ccomp[j] for j in range(m) if sfrlcomp[i2][0] & (1 << j) != 0])
                sfrly = sum([ccomp[j] for j in range(m) if sfrlcomp[i2][1] & (1 << j) != 0])
                csfrl += cs.add_sfrl_imp(sfrlx, sfrly, gap, noaux = False)
            if verbose_sfrl:
                print("========== SFRL ========= =========")
                print(csfrl)
            if enable_multiple:
                tres = cs.check_getaux_inplace(must_include = csfrl,
                    hint_pair = hint_pair, hint_aux = hint_aux)
            else:
                tres = cs.check_getaux_inplace(must_include = csfrl, single_include = csfrl,
                    hint_pair = hint_pair, hint_aux = hint_aux)
            if tres is not None:
                return tres
        if i == n:
            return None
        # Enumerate the i-th pair by total size |X| + |Y|, smallest first.
        for size in range(2, m + 1):
            for xsize in range(1, size):
                for xtuple in itertools.combinations(range(m), xsize):
                    xmask = sum([1 << x for x in xtuple])
                    yset = [j for j in range(m) if j not in xtuple]
                    for ytuple in itertools.combinations(yset, size - xsize):
                        ymask = sum([1 << y for y in ytuple])
                        sfrlcomp[i][0] = xmask
                        sfrlcomp[i][1] = ymask
                        tres = check_getaux_sfrl_recur(i + 1)
                        if tres is not None:
                            return tres
        return None

    return check_getaux_sfrl_recur(0)
def substituted_dict_union_plain(self, d):
    """Return the union of ``self.noaux()`` under every combination of substitutions.

    ``d`` maps variables to a single replacement or a list of candidates.
    Keys not present in this region are ignored. One substituted region is
    built per element of the Cartesian product of the candidate lists, each
    tagged with a "pf_note_case" meta note. A single case is returned as-is;
    several cases are combined with ``RegionOp.union``.
    """
    choices = {}
    for key, value in d.items():
        if not self.ispresent(key):
            continue
        choices[key] = value if isinstance(value, list) else [value]
    base = self.noaux()
    cases = []
    per_key = [[(key, v) for v in candidates] for key, candidates in choices.items()]
    for combo in itertools.product(*per_key):
        case = base.substituted(list(combo))
        note = CompArray(list(combo)).add_meta("omit_bracket", True).add_meta("subs", True)
        case.add_meta("pf_note_case", ["Substitute ", note, ":"])
        cases.append(case)
    if len(cases) == 1:
        return cases[0]
    return RegionOp.union(cases)
def substituted_dict_union(self, d):
    """Build ``implicant >> (union of substituted consequents).exists(remaining aux)``.

    Auxiliaries named in ``d`` are substituted via
    ``substituted_dict_union_plain``. The remaining auxiliaries stay
    existentially quantified. If ``auxi`` is non-empty, the result is
    wrapped in ``forall`` over it.
    """
    free_aux = self.aux.copy()
    for key in d.keys():
        free_aux -= key
    premise = self.imp_flippedonly_noaux()
    conclusion = self.consonly().noaux().substituted_dict_union_plain(d).exists(free_aux)
    result = premise >> conclusion
    if self.auxi.isempty():
        return result
    return result.forall(self.auxi.copy())
def check_getaux(self, hint_pair = None, hint_aux = None):
    """Return whether implication is true, with auxiliary search results.

    Returns None when the implication could not be established.
    - For a plain region (``isplain()``), returns ``[]`` when ``check()``
      succeeds.
    - Otherwise returns the first item from ``check_getaux_gen`` whose
      ``iutil.signal_type`` is ``""``. Signal tuples yielded by the
      generator are skipped. The result is presumably a list of
      (auxiliary, value) pairs -- TODO confirm against
      ``check_getaux_inplace_gen``.

    When the "proof_enabled" setting is on:
    - a "Claim:" proof step is opened before the search;
    - on success, when applicable, a substitution note and a check of the
      substituted region are recorded in that step;
    - on failure, the proof object saved at entry is restored.
    """
    truth = PsiOpts.settings["truth"]
    if truth is not None:
        # Fold the global "truth" assumption into the implicant and retry with it cleared.
        r = None
        with PsiOpts(truth = None):
            r = (truth >> self).check_getaux(hint_pair, hint_aux)
        return r
    indreg = self.get_indreg_checked()
    if indreg is not None:
        # Same for the independence region derived from the random variables.
        r = None
        with PsiOpts(indreg_enabled = False):
            r = (indreg >> self).check_getaux(hint_pair, hint_aux)
        return r
    if PsiOpts.settings["maxent_lex_enabled"] and self.maxent_present():
        # Try each region from maxent_lex_self_gen; the first success wins.
        with PsiOpts(maxent_lex_enabled = False):
            for cs in self.maxent_lex_self_gen():
                r = cs.check_getaux(hint_pair, hint_aux)
                if r is not None:
                    return r
        return None
    if self.isplain():
        r = self.check()
        if r:
            return []
        return None
    write_pf_enabled = PsiOpts.settings.get("proof_enabled", False)
    write_pf_repeat_claim = PsiOpts.settings.get("proof_repeat_implicant", False)
    oproof = None
    if write_pf_enabled:
        # Save the current proof so it can be restored if the search fails.
        oproof = PsiOpts.get_proof().copy()
        pf = ProofObj.from_region(self if write_pf_repeat_claim else self.consonly(), c = "Claim:")
        PsiOpts.set_setting(proof_step_in = pf)
    for rr in self.check_getaux_gen(hint_pair, hint_aux):
        # Only non-signal results count as success.
        if iutil.signal_type(rr) == "":
            if write_pf_enabled:
                if not self.is_getaux_op() and len(rr) > 0:
                    subdict = Comp.substitute_list_to_dict(rr, multi = True)
                    if not Comp.substitute_dict_ismulti(subdict):
                        pf = ProofObj.from_region(None, c = ["Substitute ", CompArray(subdict).add_meta("omit_bracket", True).add_meta("subs", True), ":"])
                        PsiOpts.set_setting(proof_add = pf)
                        cs = self.copy()
                        # Comp.substitute_list(cs, rr, isaux = True)
                        cs = cs.substituted_dict_union(subdict)
                        if cs.getaux().isempty():
                            # Re-check the substituted region with proofs on, so its steps are recorded.
                            with PsiOpts(proof_enabled = True):
                                if isinstance(cs, RegionOp):
                                    cs.check()
                                else:
                                    cs.check_plain()
                PsiOpts.set_setting(proof_step_out = True)
            return rr
    if write_pf_enabled:
        PsiOpts.set_setting(proof_step_out = True)
        PsiOpts.set_proof(oproof)
    return None
def check_getaux_dict(self, multi = True, **kwargs):
    """Run ``check_getaux`` and return its result as a substitution dict.

    Returns None when ``check_getaux`` returns None. ``kwargs`` are passed
    through to ``check_getaux``.
    """
    found = self.check_getaux(**kwargs)
    if found is not None:
        return Comp.substitute_list_to_dict(found, multi = multi)
    return None
def check_getaux_array(self, multi = True, **kwargs):
    """Run ``check_getaux`` and return its result as a display-ready CompArray.

    The array is tagged with the "omit_bracket" and "subs" metas. Returns
    None when ``check_getaux`` returns None. ``kwargs`` are passed through
    to ``check_getaux``.
    """
    found = self.check_getaux(**kwargs)
    if found is None:
        return None
    mapping = Comp.substitute_list_to_dict(found, multi = multi)
    arr = CompArray(mapping)
    return arr.add_meta("omit_bracket", True).add_meta("subs", True)
def solve(self, method = "c", full = False, proof = None, multi = True, display_reg = None, **kwargs):
    """Try to prove or disprove this region; return a CheckResult.

    ``method`` is a comma-separated list. With more than one entry, each is
    tried in order. The first result whose ``truth`` is not None is
    returned; otherwise the first result is returned.

    Single methods:
      "c"  -- auxiliary search via ``check_getaux``; truth=True on success.
      "e"  -- search for a model via ``forall_completed(...).example()``;
              exceptions from ``example`` are treated as "not found".
      A leading "-" runs the method on ``~self.forall_completed(truthcomp)``,
      so a success yields truth=False.

    ``full=True`` selects "c,-e,-c,e". It also defaults ``proof`` and
    ``display_reg`` to True when they are unset. Keyword arguments starting
    with "proof_" configure the proof context; the rest go to
    ``check_getaux``.
    """
    truth = PsiOpts.settings["truth"]
    truthcomp = Comp.empty() if truth is None else truth.allcomprv()
    if display_reg is None:
        display_reg = PsiOpts.settings["solve_display_reg"]
    if full:
        method = "c,-e,-c,e"
        if proof is None:
            proof = True
        if display_reg is None:
            display_reg = True

    parts = method.split(",")
    if len(parts) > 1:
        first = None
        for part in parts:
            outcome = self.solve(method = part, proof = proof, multi = multi, display_reg = display_reg, **kwargs)
            if first is None:
                first = outcome
            if outcome.truth is not None:
                return outcome
        return first

    negate = method.startswith("-")
    base_method = method[1:] if negate else method
    target = ~(self.forall_completed(truthcomp)) if negate else self.copy()

    if base_method == "c":
        proof_kwargs = {}
        search_kwargs = {}
        for key, val in kwargs.items():
            if key.startswith("proof_"):
                proof_kwargs[key] = val
            else:
                search_kwargs[key] = val
        found_proof = None
        ctx = PsiOpts()
        if proof:
            ctx = PsiOpts(proof_new = True, **proof_kwargs)
        with ctx:
            found = target.check_getaux(**search_kwargs)
            if proof:
                found_proof = PsiOpts.get_proof()
        if found is None:
            return CheckResult([], reg = self.copy(), truth = None, method = method, display_reg = display_reg)
        return CheckResult(Comp.substitute_list_to_dict(found, multi = multi), reg = self.copy(), truth = not negate, method = method, getaux = found, proof = found_proof, display_reg = display_reg)

    if base_method == "e":
        try:
            model = target.forall_completed(truthcomp).example()
        except Exception:
            # Best effort: any failure in example search counts as "no model".
            model = None
        if model is not None:
            return CheckResult([], reg = self.copy(), truth = not negate, method = method, model = model, display_reg = display_reg)

    return CheckResult([], reg = self.copy(), truth = None, method = method, display_reg = display_reg)
def is_getaux_op(self):
    """Return True if a global "truth" assumption is set or the region contains region terms."""
    if PsiOpts.settings["truth"] is not None:
        return True
    return bool(self.isregtermpresent())
def presolve_process(self, relax = True):
    """Apply the optional presolve simplifications selected in ``PsiOpts.settings``.

    With ``relax=True``, applies the aux-hull simplification (full or quick)
    and then, optionally, ``simplify``. With ``relax=False``, applies the
    aux-eq simplification (full or quick). The region is modified in place
    and returned.
    """
    if not relax:
        if PsiOpts.settings.get("presolve_aux_eq", False):
            self.simplify_aux_eq()
        elif PsiOpts.settings.get("presolve_aux_eq_quick", False):
            self.simplify_aux_eq(quick = True)
        return self

    if PsiOpts.settings.get("presolve_aux_hull", False):
        self.simplify_aux_hull()
    elif PsiOpts.settings.get("presolve_aux_hull_quick", False):
        self.simplify_aux_hull(quick = True)
    if PsiOpts.settings.get("presolve_simplify", False):
        self.simplify()
    return self
def get_indreg(self, reg_add = None, skip_abscont = False):
    """Build the implied background region for this region (and optionally ``reg_add``).

    For each real-valued expression that has a max-entropy component R_t,
    the constraint ``R >= H(R_t)`` is added. The result is intersected with
    the ``get_indreg`` region of all random variables involved.
    """
    region = Region.universe()
    real_exprs = list(self.allcompreal_exprlist()) + list(iutil.allcompreal_exprlist(reg_add))
    for real in real_exprs:
        maxent = real.get_maxent_comp()
        if maxent is None:
            continue
        region.iand_norename(real >= Expr.H(maxent))
    rvs = self.allcomprv() + region.allcomprv()
    if reg_add is not None:
        rvs = rvs + iutil.allcomprv(reg_add)
    return region & rvs.get_indreg(skip_abscont = skip_abscont)
def get_indreg_checked(self, reg_add = None):
    """Return ``get_indreg(reg_add)`` if enabled and non-trivial, else None.

    Returns None when the "indreg_enabled" setting is off, or when the
    region is None or the universe.
    """
    if not PsiOpts.settings["indreg_enabled"]:
        return None
    indreg = self.get_indreg(reg_add)
    if indreg is None or indreg.isuniverse():
        return None
    return indreg
def check_getaux_gen(self, hint_pair = None, hint_aux = None, hint_aux_avoid = None):
    """Generator that yields all auxiliary search results.

    Global "truth" assumptions, the independence region, max-entropy
    lexicographic cases and region terms are handled by delegating to a
    transformed region. Otherwise a presolved copy is searched with
    ``check_getaux_inplace_gen``. When "sfrl_level" > 0, one extra result
    may come from ``check_getaux_sfrl``.

    The global "proof_enabled" setting is forced to False while searching.
    It is set back to its entry value around every ``yield`` and at the
    end. If the search raises (including KeyboardInterrupt), the setting is
    also restored before re-raising; previously it stayed False.
    """
    truth = PsiOpts.settings["truth"]
    if truth is not None:
        with PsiOpts(truth = None):
            for rr in (truth >> self).check_getaux_gen(hint_pair, hint_aux, hint_aux_avoid):
                yield rr
        return
    indreg = self.get_indreg_checked()
    if indreg is not None:
        with PsiOpts(indreg_enabled = False):
            for rr in (indreg >> self).check_getaux_gen(hint_pair, hint_aux, hint_aux_avoid):
                yield rr
        return
    if PsiOpts.settings["maxent_lex_enabled"] and self.maxent_present():
        with PsiOpts(maxent_lex_enabled = False):
            for cs in self.maxent_lex_self_gen():
                for rr in cs.check_getaux_gen(hint_pair, hint_aux, hint_aux_avoid):
                    yield rr
        return
    if self.isregtermpresent():
        cs = RegionOp.inter([self])
        for rr in cs.check_getaux_gen(hint_pair, hint_aux, hint_aux_avoid):
            yield rr
        return
    write_pf_enabled = PsiOpts.settings.get("proof_enabled", False)
    cs = self.copy()
    cs.presolve_process()
    cs.simplify_quick(zero_group = 2)
    cs.split()
    PsiOpts.settings["proof_enabled"] = False
    try:
        sfrl_level = PsiOpts.settings["sfrl_level"]
        if hint_aux_avoid is None:
            hint_aux_avoid = []
        hint_aux_avoid = hint_aux_avoid + self.get_aux_avoid_list()
        for rr in cs.check_getaux_inplace_gen(hint_pair = hint_pair, hint_aux = hint_aux, hint_aux_avoid = hint_aux_avoid):
            # Let the consumer see the caller's proof setting while suspended here.
            PsiOpts.settings["proof_enabled"] = write_pf_enabled
            yield rr
            PsiOpts.settings["proof_enabled"] = False
        if sfrl_level > 0:
            res = cs.check_getaux_sfrl(sfrl_minsize = 1, hint_pair = hint_pair, hint_aux = hint_aux)
            if res is not None:
                PsiOpts.settings["proof_enabled"] = write_pf_enabled
                yield res
                PsiOpts.settings["proof_enabled"] = False
    except GeneratorExit:
        # Closed while suspended at a yield: the setting was already restored
        # before yielding, and the consumer may have changed it since, so
        # leave it alone.
        raise
    except BaseException:
        PsiOpts.settings["proof_enabled"] = write_pf_enabled
        raise
    PsiOpts.settings["proof_enabled"] = write_pf_enabled
def aux_gen(self, hint_pair = None, hint_aux = None, hint_aux_avoid = None):
    """Generator that yields all auxiliary search results (alias of ``check_getaux_gen``)."""
    gen = self.check_getaux_gen(hint_pair, hint_aux, hint_aux_avoid)
    return gen
def check(self):
    """Return whether implication is true.

    Dispatch order:
    1. Use ``check_z3`` when the solver is "z3".
    2. Fold any global "truth" assumption, then the independence region,
       into the implicant and re-check.
    3. For max-entropy lexicographic cases, return True if any case checks.
    4. For plain regions, use ``check_plain``.
    5. Otherwise fall back to the auxiliary search.
    """
    if iutil.get_solver() == "z3":
        return self.check_z3()
    truth = PsiOpts.settings["truth"]
    if truth is not None:
        with PsiOpts(truth = None):
            return (truth >> self).check()
    indreg = self.get_indreg_checked()
    if indreg is not None:
        with PsiOpts(indreg_enabled = False):
            return (indreg >> self).check()
    if PsiOpts.settings["maxent_lex_enabled"] and self.maxent_present():
        with PsiOpts(maxent_lex_enabled = False):
            return any(case.check() for case in self.maxent_lex_self_gen())
    if self.isplain():
        return self.check_plain()
    return self.check_getaux() is not None
def assumption(self, mode = None):
    """Retrieve the strengthened assumptions for proving this region.

    The implication in this region must be true if this assumption does not
    hold. The computation is delegated to a single-element
    ``RegionOp.inter`` wrapper.
    """
    wrapped = RegionOp.inter([self])
    return wrapped.assumption(mode = mode)
def truth(self):
    """The region given by the assumptions.

    This is the intersection of the global "truth" setting (if set) and the
    checked independence region (if any). It is the universe when neither
    is present.
    """
    result = Region.universe()
    assumed = PsiOpts.settings["truth"]
    if assumed is not None:
        result &= assumed
    implied = self.get_indreg_checked()
    if implied is not None:
        result &= implied
    return result
def evalcheck(self, f):
    """Numerically check the implication under the evaluator ``f``.

    ``f`` maps an expression to a value accepted by ``float``. Constraints
    are compared within the "eps_check" tolerance; NaN values count as
    violations.

    - If any implicant constraint (``exprs_gei`` >= 0, ``exprs_eqi`` == 0)
      is violated, the implication holds vacuously and True is returned.
    - Otherwise, returns whether every consequent constraint
      (``exprs_ge`` >= 0, ``exprs_eq`` == 0) holds.

    Global "truth" and independence assumptions are folded into the
    implicant first.
    """
    truth = PsiOpts.settings["truth"]
    if truth is not None:
        with PsiOpts(truth = None):
            return (truth >> self).evalcheck(f)
    indreg = self.get_indreg_checked()
    if indreg is not None:
        with PsiOpts(indreg_enabled = False):
            return (indreg >> self).evalcheck(f)
    tol = PsiOpts.settings["eps_check"]

    def nonneg(e):
        return float(f(e)) >= -tol

    def nearzero(e):
        return abs(float(f(e))) <= tol

    if not all(map(nonneg, self.exprs_gei)) or not all(map(nearzero, self.exprs_eqi)):
        return True
    return all(map(nonneg, self.exprs_ge)) and all(map(nearzero, self.exprs_eq))
def evalcheck_vec(self, f):
    """Vectorized counterpart of evalcheck.

    The comparison results of all premises are multiplied together, as are
    those of all consequences, and the method returns
    1 - premises * (1 - consequences). f may therefore return scalars or
    (presumably) array values that support comparison, abs and `*`.
    """
    truth = PsiOpts.settings["truth"]
    if truth is not None:
        with PsiOpts(truth = None):
            return (truth >> self).evalcheck_vec(f)
    indreg = self.get_indreg_checked()
    if indreg is not None:
        with PsiOpts(indreg_enabled = False):
            return (indreg >> self).evalcheck_vec(f)

    tol = PsiOpts.settings["eps_check"]
    premise = 1
    concl = 1
    for x in self.exprs_gei:
        premise = premise * (f(x) >= -tol)
    for x in self.exprs_eqi:
        premise = premise * (abs(f(x)) <= tol)
    for x in self.exprs_ge:
        concl = concl * (f(x) >= -tol)
    for x in self.exprs_eq:
        concl = concl * (abs(f(x)) <= tol)
    return 1 - (premise * (1 - concl))
def eval_max_violate(self, f):
    """Return the largest violation of the consequences under f.

    Returns 0.0 if some premise evaluates to a non-NaN value that is
    violated beyond "eps_check". Returns numpy.inf if a consequence
    evaluates to NaN. Otherwise returns the maximum of 0.0, -f(x) over
    exprs_ge and |f(x)| over exprs_eq.
    """
    truth = PsiOpts.settings["truth"]
    if truth is not None:
        with PsiOpts(truth = None):
            return (truth >> self).eval_max_violate(f)
    indreg = self.get_indreg_checked()
    if indreg is not None:
        with PsiOpts(indreg_enabled = False):
            return (indreg >> self).eval_max_violate(f)

    tol = PsiOpts.settings["eps_check"]
    for x in self.exprs_gei:
        val = float(f(x))
        if not (numpy.isnan(val) or val >= -tol):
            return 0.0
    for x in self.exprs_eqi:
        val = float(f(x))
        if not (numpy.isnan(val) or abs(val) <= tol):
            return 0.0

    worst = 0.0
    for exprs, penalty in ((self.exprs_ge, lambda v: -v), (self.exprs_eq, abs)):
        for x in exprs:
            val = float(f(x))
            if numpy.isnan(val):
                return numpy.inf
            worst = max(worst, penalty(val))
    return worst
def eval_sum_violate(self, f, pow = 1, leak = 0.1):
    """Return a soft violation score for the consequences under f.

    For each x in exprs_ge with t = f(x): add (-t) ** pow when t < 0,
    otherwise subtract leak * t ** pow (a small reward for slack). For
    each x in exprs_eq: add (-t) ** pow or t ** pow depending on the sign.
    NaN values are skipped. Premises and the global "truth" setting are
    not consulted.
    """
    total = 0.0
    for x in self.exprs_ge:
        t = ConcReal.unbox(f(x))
        if numpy.isnan(float(t)):
            continue
        if t < 0:
            total = total + (-t) ** pow
        else:
            total = total - (t ** pow) * leak
    for x in self.exprs_eq:
        t = ConcReal.unbox(f(x))
        if numpy.isnan(float(t)):
            continue
        total = total + ((-t) ** pow if t < 0 else t ** pow)
    return total
def expr_sum_violate(self, *args, **kwargs):
    """Return an Expr whose value on a model P is self.eval_sum_violate(P, *args, **kwargs)."""
    def score(P):
        return self.eval_sum_violate(P, *args, **kwargs)
    return Expr.fcn(score)
def example(self, card = None, denom = None):
    """Search for a concrete distribution that satisfies this region.

    Parameters:
        card: cardinality, or a list of cardinalities tried in order, used
            as "opt_aux_card"; defaults to the "opt_example_card" setting.
        denom: if > 0, first try snapping the found model to fractions with
            this maximal denominator; defaults to the "max_denom_try"
            setting.

    Returns the fraction-snapped model if it still satisfies the region
    without auxiliaries, otherwise P.opt_model(). Returns None if no
    cardinality yields a model.
    """
    truth = PsiOpts.settings["truth"]
    if truth is not None:
        with PsiOpts(truth = None):
            # Forward denom as well; it used to be dropped on this path.
            return (truth & self).example(card = card, denom = denom)
    indreg = self.get_indreg_checked()
    if indreg is not None:
        with PsiOpts(indreg_enabled = False):
            return (indreg & self).example(card = card, denom = denom)

    cs = self.exists(self.allcomprv() - self.getaux() - self.getauxi())
    if card is None:
        card = PsiOpts.settings["opt_example_card"]
    if card is None or isinstance(card, int):
        card = [card]
    if denom is None:
        denom = PsiOpts.settings["max_denom_try"]

    for ccard in card:
        P = ConcModel()
        opts = {"opt_num_points_mul": PsiOpts.settings["example_opt_num_points_mul"]}
        if ccard is not None:
            opts["opt_aux_card"] = ccard
        with PsiOpts(**opts):
            if not P[cs]:
                continue
            if denom > 0:
                P2 = P.opt_model().copy()
                P2.fraction_snap(denom = denom, eps = numpy.inf)
                if P2[cs.noaux()]:
                    P2.set_force_float(5)
                    return P2
            return P.opt_model()
    return None
def add_real_to(self, P, rel_gap = 0.15):
    """Assign heuristic values to the real variables of this region in P.

    Every real variable is first set to 0.0. Then, over two sweeps, each
    real is placed using the bounds implied by the exprs_ge / exprs_eq
    constraints that contain it, with every other variable held at its
    current value in P:
      - both bounds: the midpoint of [lb, ub];
      - only lb: lb moved upward by the relative margin rel_gap;
      - only ub: ub moved downward by the relative margin rel_gap;
      - no bound: the current value is kept.

    P is mutated in place; returns None.
    """
    reals = self.allcomprealvar_exprlist()
    if len(reals) == 0:
        return
    for real in reals:
        P[real] = 0.0
    exprs_ge_real = [x for x in self.exprs_ge if not x.type_only(TermType.REAL).iszero()]
    exprs_eq_real = [x for x in self.exprs_eq if not x.type_only(TermType.REAL).iszero()]
    cinf = 1e20  # sentinel meaning "no bound found"
    for _ in range(2):
        for real in reals:
            lb = -cinf
            ub = cinf
            for x in exprs_ge_real:
                c = x.get_coeff(real)
                if c == 0:
                    continue
                # x = c * real + rest >= 0, so c * real >= -rest.
                v = float(P[x - real * c])
                if c > 0:
                    lb = max(lb, -v / c)
                else:
                    ub = min(ub, -v / c)
            for x in exprs_eq_real:
                c = x.get_coeff(real)
                if c == 0:
                    continue
                v = float(P[x - real * c])
                lb = max(lb, -v / c)
                ub = min(ub, -v / c)
            has_lb = lb != -cinf
            has_ub = ub != cinf
            if has_lb and has_ub:
                P[real] = (lb + ub) * 0.5
            elif has_lb:
                P[real] = lb * (1.0 + rel_gap * (1 if lb > 0 else -1))
            elif has_ub:
                # Previously tested "ub != -cinf", which is always true and
                # wrote about 1e20 into unbounded variables.
                P[real] = ub * (1.0 - rel_gap * (1 if ub > 0 else -1))
def example_bsc(self, addrv = None, bnet = None, crossover = 0.1, rel_gap = 0.15):
    """Build a binary example model that follows a Bayesian network.

    Variables are visited in the order: variables of the topologically
    sorted network (self.get_bayesnet() unless bnet is given), then
    self.allcomprv(), then addrv. Each variable is binary. Its conditioning
    set is its network parents, or, for a variable outside the network,
    every previously placed variable. Given conditioning values xs, the
    variable is 1 with probability
    crossover + (1 - 2 * crossover) * sum(xs) / in_max, where in_max is the
    largest possible sum; if in_max is 0 the probability is 0.5. Real
    variables are then filled in by add_real_to.
    """
    P = ConcModel()
    net = (self.get_bayesnet() if bnet is None else bnet).tsorted()
    order = net.index.comprv + self.allcomprv()
    if addrv is not None:
        order += addrv
    placed = Comp.empty()
    for a in order:
        parents = net.get_parents(a) if net.index.comprv.ispresent(a) else placed
        shape_in = tuple(P.get_card(t) for t in parents)
        in_max = sum(k - 1 for k in shape_in)
        p = ConcDist(shape = (shape_in, 2))
        for xs in itertools.product(*(range(k) for k in shape_in)):
            if in_max:
                c = sum(xs) * (1.0 - 2 * crossover) / in_max + crossover
            else:
                c = 0.5
            p[xs + (0,)] = 1.0 - c
            p[xs + (1,)] = c
        P[a | parents] = p
        placed += a
    self.add_real_to(P, rel_gap = rel_gap)
    return P
# r_real = Region.universe()
# r_real.exprs_ge = exprs_ge_real
# r_real.exprs_eq = exprs_eq_real
# # def leaky(x):
# # if x < 0:
# # return x
# # return x * rel_gap
# # print(exprs_ge_real)
# # print(reals)
# bd_gap = 0.001
# obj = lambda P: sum(1.0 / (max(P[a].get_x() + bd_gap, 0.0) + bd_gap) for a in exprs_ge_real) + sum(P[a].get_x() ** 2 for a in reals) * rel_gap
# P.minimize(obj, [P[a] for a in reals], r_real)
# print(obj(P))
# # P.maximize(lambda P: sum(P[a] for a in exprs_ge_real) ,
# # [P[a] for a in reals], r_real)
# return P
def implies(self, other, quick = False, bnet = None, **kwargs):
    """Return whether self implies other.

    Parameters:
        other: the region to be implied.
        quick: if True, use check_quick (which receives bnet) instead of
            the full check().
        bnet: optional Bayesian network forwarded to check_quick.
        **kwargs: simplification options. PsiOpts.setting_strengthen_split
            divides them into options for simplifying self and options for
            simplifying other; each set is applied with a "simplify_" prefix
            before the check.
    """
    if kwargs:
        k0, k1 = PsiOpts.setting_strengthen_split(kwargs)
        cs = self
        if k0:
            with PsiOpts(**{"simplify_" + key: val for key, val in k0.items()}):
                cs = cs.simplified()
        if k1:
            with PsiOpts(**{"simplify_" + key: val for key, val in k1.items()}):
                other = other.simplified()
        # Forward bnet too; it used to be dropped on this path.
        return cs.implies(other, quick = quick, bnet = bnet)
    if quick:
        return (self <= other).check_quick(bnet = bnet)
    return (self <= other).check()
def implies_getaux(self, other, hint_pair = None, hint_aux = None):
    """Return the auxiliary search result for "self implies other".

    The value from check_getaux is returned as is (None when the
    implication could not be established).
    """
    return (self <= other).check_getaux(hint_pair, hint_aux)
#auxlist = other.aux.varlist + self.auxi.varlist
#return [(Comp([auxlist[i]]), res[i][1]) for i in range(len(res))]
def implies_getaux_gen(self, other, hint_pair = None, hint_aux = None):
    """Generate every auxiliary search result for "self implies other"."""
    yield from (self <= other).check_getaux_gen(hint_pair, hint_aux)
def equiv(self, other, quick = False, **kwargs):
    """Return whether self and other imply each other.

    The same options are used in both directions. If the forward
    implication is falsy, that value is returned without checking the
    reverse direction.
    """
    forward = self.implies(other, quick = quick, **kwargs)
    if not forward:
        return forward
    return other.implies(self, quick = quick, **kwargs)
# def allcomp(self):
# index = IVarIndex()
# self.record_to(index)
# return index.comprv + index.compreal
# def allcomprv(self):
# index = IVarIndex()
# self.record_to(index)
# return index.comprv
# def allcompreal(self):
# index = IVarIndex()
# self.record_to(index)
# return index.compreal
def indep_components(self):
    """Return the independent components of this region's Bayesian network."""
    net = self.get_bayesnet()
    return net.indep_components()
def abscont_preserved(self):
    """Return the region "get_ic(include_ic=False, include_hc=True) <= 0"."""
    hc_part = self.get_ic(include_ic = False, include_hc = True)
    return hc_part <= 0
def maxent_lex_region(self, a, s, varadd = None):
    """Region given by the lex operation.
    Mokshay Madiman, Adam W Marcus, and Prasad Tetali, "Entropy and set cardinality inequalities
    for partition-determined functions", Random Structures & Algorithms 40, 4 (2012), pp. 399--424.

    For every RV b among the relevant variables and every split of the
    list s into a selected part s0 and the rest s1: if check_plain shows
    that this region implies ~Expr.Hc(b, s0) & ~Expr.Hc(a, b + s1)
    (presumably H(b|s0) = 0 and H(a|b,s1) = 0), then ~Expr.Hc(s0, b) is
    added to the result.
    """
    allvars = self.allcomprv() + iutil.allcomprv(a) + iutil.allcomprv(s) + iutil.allcomprv(varadd)
    indreg = self.get_indreg_checked(reg_add = allvars)
    if indreg is not None:
        with PsiOpts(indreg_enabled = False):
            return (self & indreg).maxent_lex_region(a, s, varadd = varadd)
    r = Region.universe()
    for b in allvars:
        for mask in range(1 << len(s)):
            s0 = Comp.empty()
            s1 = Comp.empty()
            for i, x in enumerate(s):
                if mask >> i & 1:
                    s0 = s0 + x
                else:
                    s1 = s1 + x
            premise = ~Expr.Hc(b, s0) & ~Expr.Hc(a, b + s1)
            if (self >> premise).check_plain():
                r.iand_norename(~Expr.Hc(s0, b))
    return r
def minimal_dependency(self, x, ys):
    """Greedily find a minimal subset of ys that determines x.

    Returns None if check_plain cannot show that this region implies
    ~Expr.Hc(x, sum(ys)). Otherwise returns a bitmask over the indices of
    ys. Starting from all of ys, index i (in increasing order) is dropped
    if the ys still kept, excluding i, determine x.
    """
    if not (self >> ~Expr.Hc(x, sum(ys, Comp.empty()))).check_plain():
        return None
    keep = (1 << len(ys)) - 1
    for i in range(len(ys)):
        rest = sum((b for j, b in enumerate(ys) if j != i and keep >> j & 1), Comp.empty())
        if (self >> ~Expr.Hc(x, rest)).check_plain():
            keep &= ~(1 << i)
    return keep
def maxent_lex_region_gen(self, vs, varadd = None, nochange = None):
    """Region given by the lex operation.
    Mokshay Madiman, Adam W Marcus, and Prasad Tetali, "Entropy and set cardinality inequalities
    for partition-determined functions", Random Structures & Algorithms 40, 4 (2012), pp. 399--424.

    Generator. First yields self. Then, for every admissible nonempty
    combination of candidates from vs, it yields a simplified region built
    from the candidates' lex regions (maxent_lex_region) and the part of
    self on the unused independent components.

    Parameters:
        vs: list of (v, force) pairs. A candidate is unusable if v is in
            nochange, if minimal_dependency finds no dependency set, or if
            that set touches nochange. An unusable forced candidate ends the
            generator (after self was yielded). Unusable unforced candidates
            are skipped, as are unforced ones whose lex region is the
            universe.
        varadd: extra variables forwarded to maxent_lex_region.
        nochange: Comp of RVs that must not be involved; defaults to empty.
    """
    if nochange is None:
        nochange = Comp.empty()
    vars = self.allcomprv() + iutil.allcomprv(vs) + iutil.allcomprv(varadd)
    indreg = self.get_indreg_checked(reg_add = vars)
    if indreg is not None:
        # Fold the independence region into self, then recurse with it disabled.
        with PsiOpts(indreg_enabled = False):
            for r in (self & indreg).maxent_lex_region_gen(vs, varadd = varadd, nochange = nochange):
                yield r
        return
    yield self
    # Independent components of the Bayesian network. With at most one
    # component, all variables are treated as a single component.
    icomps = self.indep_components()
    if len(icomps) <= 1:
        icomps = [vars]
    mask_force = 0  # bitmask over cvs of the forced candidates
    cvs = []  # usable candidates: (v, dependency mask over icomps, lex region, force)
    for v, force in vs:
        if nochange.ispresent(v):
            if force:
                return
            continue
        dep = self.minimal_dependency(v, icomps)
        if dep is None:
            if force:
                return
            continue
        deplist = [y for j, y in enumerate(icomps) if dep & (1 << j)]
        # if not force and len(deplist) <= 1:
        #     continue
        if any(nochange.ispresent(x) for x in deplist):
            if force:
                return
            continue
        treg = self.maxent_lex_region(v, deplist, varadd = varadd)
        if not force and treg.isuniverse():
            continue
        if force:
            mask_force += 1 << len(cvs)
        cvs.append((v, dep, treg, force))
    # Enumerate nonempty subsets of candidates that contain every forced
    # candidate and whose dependency masks are pairwise disjoint.
    for mask in range(1, 1 << len(cvs)):
        if mask & mask_force != mask_force:
            continue
        tvs = []  # union of each chosen candidate's dependency components
        sdep = 0  # union of dependency masks so far; -1 marks an overlap
        r = Region.universe()
        for i, (v, dep, treg, force) in enumerate(cvs):
            if not mask & (1 << i):
                continue
            if sdep & dep:
                sdep = -1
                break
            sdep |= dep
            depsum = sum((y for j, y in enumerate(icomps) if dep & (1 << j)), Comp.empty())
            tvs.append(depsum)
            r.iand_norename((H0(v) == H(v)).add_meta("pf_note", ["lex ", depsum]))
            r.iand_norename(treg)
        if sdep == -1:
            continue
        # Components not used by any chosen candidate.
        depleft = [y for j, y in enumerate(icomps) if not sdep & (1 << j)]
        cs = Region.universe()
        if depleft:
            # Keep self, relaxed (remove_relax) on every RV outside the unused components.
            cs = self.copy()
            cs.remove_relax(vars - sum(depleft, Comp.empty()))
        r = (r & cs & self.abscont_preserved() & self.get_indreg(skip_abscont = True) & indep(*depleft, *tvs))
        r = r.simplified_quick()
        yield r
def maxent_ub_list(self):
    """List the distinct max-entropy components that are upper bounded.

    A term's get_maxent_comp() is collected when the term has a
    coefficient that is not >= 0 in some exprs_ge entry, or any coefficient
    in some exprs_eq entry. None values are ignored, and first-seen order
    is kept.
    """
    found = []

    def consider(term):
        comp = term.get_maxent_comp()
        if comp is not None and comp not in found:
            found.append(comp)

    for x in self.exprs_ge:
        for term, coeff in x.terms:
            if not coeff >= 0:
                consider(term)
    for x in self.exprs_eq:
        for term, _ in x.terms:
            consider(term)
    return found
def abscont_nochange_comp(self):
    """Return every RV that appears in a term whose get_maxent_comp() is None.

    Only the terms of exprs_ge and exprs_eq are scanned.
    """
    r = Comp.empty()
    for expr in self.exprs_ge + self.exprs_eq:
        for term, _ in expr.terms:
            if term.get_maxent_comp() is None:
                r += term.allcomprv()
    return r
def maxent_present(self):
    """Return whether some real-valued expression of this region has a
    max-entropy component (get_maxent_comp() is not None)."""
    for x in self.allcompreal_exprlist():
        if x.get_maxent_comp() is not None:
            return True
    return False
def maxent_lex_self_gen(self, varadd = None):
    """Generate candidate regions obtained by the lex operation on self.

    Every component from maxent_ub_list() is a forced candidate, and the
    RVs from abscont_nochange_comp() are held fixed. For each region tr
    produced by imp_flippedonly_noaux().maxent_lex_region_gen(...), yields
    (tr >> self.consonly()).forall(self.auxi).
    """
    forced = [(comp, True) for comp in self.maxent_ub_list()]
    source = self.imp_flippedonly_noaux()
    fixed = self.abscont_nochange_comp()
    for tr in source.maxent_lex_region_gen(forced, varadd = varadd, nochange = fixed):
        yield (tr >> self.consonly()).forall(self.auxi)
def allcomprealvar(self):
    """Return the union of allcomprealvar() over all four constraint lists."""
    r = Comp.empty()
    for group in (self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi):
        for x in group:
            r += x.allcomprealvar()
    return r
# index = IVarIndex()
# self.record_to(index)
# return index.compreal - Comp([IVar.eps(), IVar.one()])
def allcompreal_exprlist(self):
    """Return an ExprArray that merges (with iadd_noduplicate) the
    allcompreal_exprlist() of every expression in the four constraint lists."""
    r = ExprArray([])
    for group in (self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi):
        for x in group:
            r.iadd_noduplicate(x.allcompreal_exprlist())
    return r
def allcomprealvar_exprlist(self):
    """Return an ExprArray that merges (with iadd_noduplicate) the
    allcomprealvar_exprlist() of every expression in the four constraint lists."""
    r = ExprArray([])
    for group in (self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi):
        for x in group:
            r.iadd_noduplicate(x.allcomprealvar_exprlist())
    return r
def allcomprv_noaux(self):
    """Return all RVs of this region except the auxiliaries (getauxall())."""
    everything = self.allcomprv()
    return everything - self.getauxall()
def aux_remove(self):
    """Clear both auxiliary lists (aux and auxi) in place and return self."""
    self.aux, self.auxi = Comp.empty(), Comp.empty()
    return self
def completed_semigraphoid_ic(self, vs = None, max_iter = None):
    """Close the conditional independences of this region under
    semigraphoid-style rules and return them as a sum of Ic terms.

    Each independence is a bitmask triple (m0, m1, mz) over the RVs of
    self.get_ic(), meaning I(m0; m1 | mz). New triples are derived by
    contraction, I(X;Y|Z) and I(X;W|Y,Z) giving I(X;Y,W|Z), until a round
    adds nothing new or the attempt counter exceeds max_iter. After each
    round, triples implied by another triple (see mask_impl) are pruned.

    Parameters:
        vs: optional list of Comp or (target, source) pairs; a bare Comp t
            means (t, t). When given, the result is re-expressed over these
            variables: masks come from the sources and the output uses the
            targets. An empty list gives Expr.zero().
        max_iter: optional bound on the number of combination attempts.

    Returns an Expr (a sum of Expr.Ic terms). Set the PsiOpts setting
    "verbose_semigraphoid" to print the derivations.
    """
    verbose = PsiOpts.settings.get("verbose_semigraphoid", False)
    if vs is not None:
        # Normalize vs to (target, source) pairs.
        tvs = []
        for t in vs:
            if isinstance(t, Comp):
                tvs.append((t, t))
            else:
                tvs.append(t)
        if len(tvs) == 0:
            return Expr.zero()
        vs = tvs
    def mask_impl(a, b):
        # Whether triple a = (a0, a1, az) implies triple b = (b0, b1, bz):
        # move the parts of a0 and a1 that lie in bz into az (weak union),
        # then require equal conditioning sets and b0, b1 to be subsets of
        # a0, a1 (decomposition).
        a0, a1, az = a
        b0, b1, bz = b
        a0x = a0 & (bz & ~az)
        a0 &= ~a0x
        az |= a0x
        a1x = a1 & (bz & ~az)
        a1 &= ~a1x
        az |= a1x
        if az != bz:
            return False
        return a0 | b0 == a0 and a1 | b1 == a1
    icexpr = self.get_ic()
    index = IVarIndex()
    icexpr.record_to(index)
    # Current set of triples (bitmasks over index).
    icl = set()
    if verbose:
        print("========== SEMIGRAPHOID =========")
        print(icexpr)
        print("=====================================")
    for a, c in icexpr.terms:
        if len(a.x) == 1:
            # Single-component term (presumably H(X|Z)), stored as I(X; X | Z).
            mz = index.get_mask(a.z)
            m0 = index.get_mask(a.x[0]) & ~mz
            icl.add((m0, m0, mz))
        elif len(a.x) == 2:
            mz = index.get_mask(a.z)
            m0 = index.get_mask(a.x[0]) & ~mz
            m1 = index.get_mask(a.x[1]) & ~mz
            # Canonical order: smaller side first.
            if m0 > m1:
                m0, m1 = m1, m0
            icl.add((m0, m1, mz))
    #for a0, a1, az in icl:
    #    print(str(Expr.Ic(index.from_mask(a0), index.from_mask(a1), index.from_mask(az))))
    citer = 0  # combination attempts so far, compared against max_iter
    iclw = icl.copy()  # every triple ever derived, so none is re-added
    did = True
    while did:
        did = False
        icl2 = icl.copy()
        for a0k, a1k, azk in icl:
            if max_iter is not None and citer > max_iter:
                break
            for b0k, b1k, bzk in icl:
                if max_iter is not None and citer > max_iter:
                    break
                if (a0k, a1k, azk) == (b0k, b1k, bzk):
                    continue
                # Each conditioning set must lie within the other triple's variables.
                if azk & ~(b0k | b1k | bzk) != 0:
                    continue
                if bzk & ~(a0k | a1k | azk) != 0:
                    continue
                # Try all four orientations (sides swapped or not) of a and b.
                for aj in range(2):
                    for bj in range(2):
                        citer += 1
                        if max_iter is not None and citer > max_iter:
                            break
                        a0, a1, az = a0k, a1k, azk
                        b0, b1, bz = b0k, b1k, bzk
                        if aj != 0:
                            a0, a1 = a1, a0
                        a0o, a1o, azo = a0, a1, az
                        if bj != 0:
                            b0, b1 = b1, b0
                        b0o, b1o, bzo = b0, b1, bz
                        # The first sides must share a variable.
                        if a0 & b0 == 0:
                            continue
                        # Weak union on b: parts of b0 conditioned on in a move into bz.
                        b0x = b0 & (az & ~bz)
                        b0 &= ~b0x
                        bz |= b0x
                        # Weak union on a: parts of a0/a1 inside b1 | bz move into az.
                        b1z = b1 | bz
                        a0x = a0 & (b1z & ~az)
                        a0 &= ~a0x
                        az |= a0x
                        a1x = a1 & (b1z & ~az)
                        a1 &= ~a1x
                        az |= a1x
                        # Decomposition: restrict the sides.
                        b1 &= az
                        a0 &= b0 & ~bz
                        a1 &= ~bz
                        b1 &= ~bz
                        # if verbose:
                        #     print(str(Expr.Ic(index.from_mask(a0o), index.from_mask(a1o), index.from_mask(azo)))
                        #         + " - " + str(Expr.Ic(index.from_mask(b0o), index.from_mask(b1o), index.from_mask(bzo)))
                        #         + " : " + str(Expr.Ic(index.from_mask(a0), index.from_mask(a1 | b1), index.from_mask(bz)))
                        #         + " / " + str(index.from_mask(az)) + "=" + str(index.from_mask(b1 | bz))
                        #         )
                        # Contraction applies when a is I(X; W | Y, Z) and b is
                        # I(X; Y | Z), i.e. az == b1 | bz; the result is I(X; W, Y | Z).
                        if az != b1 | bz:
                            continue
                        if a0 == 0 or a1 | b1 == 0:
                            continue
                        t = (a0, a1 | b1, bz)
                        # Skip results already implied by either input.
                        if mask_impl((a0o, a1o, azo), t) or mask_impl((b0o, b1o, bzo), t):
                            continue
                        if verbose:
                            print(str(citer) + ": " + str(Expr.Ic(index.from_mask(a0o), index.from_mask(a1o), index.from_mask(azo)))
                                + " & " + str(Expr.Ic(index.from_mask(b0o), index.from_mask(b1o), index.from_mask(bzo)))
                                + " -> " + str(Expr.Ic(index.from_mask(a0), index.from_mask(a1 | b1), index.from_mask(bz)))
                                )
                        if t[0] > t[1]:
                            t = (t[1], t[0], t[2])
                        if t in iclw:
                            continue
                        #print(str(Expr.Ic(index.from_mask(t[0]), index.from_mask(t[1]), index.from_mask(t[2]))))
                        icl2.add(t)
                        iclw.add(t)
                        did = True
        # Prune: keep only triples not implied by another triple (in either orientation).
        icl.clear()
        for a0, a1, az in icl2:
            for b0, b1, bz in icl2:
                if (a0, a1, az) == (b0, b1, bz):
                    continue
                if (mask_impl((b0, b1, bz), (a0, a1, az))
                    or mask_impl((b0, b1, bz), (a1, a0, az))):
                    break
            else:
                icl.add((a0, a1, az))
    if vs is not None:
        # Re-express the triples over the entries of vs (bit i = vs[i]).
        vmasks = [index.get_mask(b) for a, b in vs]
        vvars = CompArray([a for a, b in vs])
        icl2 = set()
        for a0, a1, az in icl:
            # b0 / b1: vs entries fully contained in side a0 / a1.
            b0 = 0
            b1 = 0
            bz = 0
            for i in range(len(vs)):
                if a0 | vmasks[i] == a0:
                    b0 |= 1 << i
                if a1 | vmasks[i] == a1:
                    b1 |= 1 << i
            if b0 == 0 or b1 == 0:
                continue
            def recur(i, zmask, azvis):
                # Enumerate sets zmask of vs entries to condition on. Each
                # chosen entry must intersect az and lie inside a0 | a1 | az,
                # and together they must cover az.
                if i == len(vs):
                    if azvis & az == az:
                        c0 = b0 & ~zmask
                        c1 = b1 & ~zmask
                        if c0 != 0 and c1 != 0:
                            if c0 > c1:
                                c0, c1 = c1, c0
                            icl2.add((c0, c1, zmask))
                    return
                vm = vmasks[i]
                recur(i + 1, zmask, azvis)
                if vm & az != 0 and (a0 | a1 | az) & vm == vm:
                    recur(i + 1, zmask + (1 << i), azvis | vm)
            recur(0, 0, 0)
        # Prune implied triples again, now in vs coordinates.
        icl.clear()
        for a0, a1, az in icl2:
            for b0, b1, bz in icl2:
                if (a0, a1, az) == (b0, b1, bz):
                    continue
                if (mask_impl((b0, b1, bz), (a0, a1, az))
                    or mask_impl((b0, b1, bz), (a1, a0, az))):
                    break
            else:
                icl.add((a0, a1, az))
        r = Expr.zero()
        for a0, a1, az in icl:
            r += Expr.Ic(vvars.from_mask(a0), vvars.from_mask(a1), vvars.from_mask(az))
        return r
    r = Expr.zero()
    for a0, a1, az in icl:
        r += Expr.Ic(index.from_mask(a0), index.from_mask(a1), index.from_mask(az))
    return r
def completed_semigraphoid_ic_new(self, vs = None, max_iter = None, include_hc = True):
    """Variant of completed_semigraphoid_ic that also uses functional
    dependencies.

    Single-component terms of get_ic(include_hc=include_hc) are not stored
    as triples. They become dependencies (mz, m0): m0, the closure of
    mz | x, is determined by mz. Every triple is normalized by clean(),
    which closes the conditioning set and both sides under these
    dependencies. The contraction closure, pruning and the vs translation
    are the same as in completed_semigraphoid_ic. The loop also stops once
    PsiOpts.is_timer_ended() is true. The result contains one Expr.Hc term
    per dependency in addition to the Expr.Ic terms.
    """
    verbose = PsiOpts.settings.get("verbose_semigraphoid", False)
    if vs is not None:
        # Normalize vs to (target, source) pairs.
        tvs = []
        for t in vs:
            if isinstance(t, Comp):
                tvs.append((t, t))
            else:
                tvs.append(t)
        if len(tvs) == 0:
            return Expr.zero()
        vs = tvs
    def mask_impl(a, b):
        # Whether triple a implies triple b (weak union + decomposition);
        # same as in completed_semigraphoid_ic.
        a0, a1, az = a
        b0, b1, bz = b
        a0x = a0 & (bz & ~az)
        a0 &= ~a0x
        az |= a0x
        a1x = a1 & (bz & ~az)
        a1 &= ~a1x
        az |= a1x
        if az != bz:
            return False
        return a0 | b0 == a0 and a1 | b1 == a1
    icexpr = self.get_ic(include_hc = include_hc)
    index = IVarIndex()
    icexpr.record_to(index)
    icl = set()
    nvar = len(index.comprv)
    # fcn_masks = [1 << i for i in range(nvar)]
    fcns = []  # functional dependencies (mz, m0): the RVs in m0 are determined by mz
    def fcn_expand(x):
        # Close mask x under the functional dependencies (fixed-point iteration).
        did = True
        while did:
            did = False
            for mz, m0 in fcns:
                if x & mz == mz and x | m0 != x:
                    x |= m0
                    did = True
        return x
    def clean(a):
        # Normalize a triple: close az under the dependencies, close each
        # side together with az and remove az from it, and put the smaller
        # side first.
        a0, a1, az = a
        # for i in range(nvar):
        #     if a0 & (1 << i):
        #         a0 |= fcn_masks[i]
        #     if a1 & (1 << i):
        #         a1 |= fcn_masks[i]
        #     if az & (1 << i):
        #         az |= fcn_masks[i]
        az = fcn_expand(az)
        a0 = fcn_expand(a0 | az) & ~az
        a1 = fcn_expand(a1 | az) & ~az
        if a0 > a1:
            a0, a1 = a1, a0
        return (a0, a1, az)
    if verbose:
        print("========== SEMIGRAPHOID =========")
        print(icexpr)
        print("=====================================")
    # Collect the dependencies from single-component terms, then close each
    # right-hand side under all of them.
    for a, c in icexpr.terms:
        if len(a.x) == 1:
            mz = index.get_mask(a.z)
            m0 = index.get_mask(a.x[0])
            fcns.append((mz, mz | m0))
    fcns = [(mz, fcn_expand(m0)) for mz, m0 in fcns]
    for a, c in icexpr.terms:
        if len(a.x) == 1:
            pass
            # mz = index.get_mask(a.z)
            # m0 = index.get_mask(a.x[0]) & ~mz
            # m1 = ((1 << nvar) - 1) & ~mz
            # # icl.add((m0, m0, mz))
            # icl.add((m0, m1, mz))
        elif len(a.x) == 2:
            mz = index.get_mask(a.z)
            m0 = index.get_mask(a.x[0]) & ~mz
            m1 = index.get_mask(a.x[1]) & ~mz
            # if m0 > m1:
            #     m0, m1 = m1, m0
            icl.add(clean((m0, m1, mz)))
    # for a0, a1, az in icl:
    #     print(str(Expr.Ic(index.from_mask(a0), index.from_mask(a1), index.from_mask(az))))
    citer = 0  # combination attempts so far, compared against max_iter
    iclw = icl.copy()  # every triple ever derived
    did = True
    while did:
        did = False
        icl2 = icl.copy()
        # Check contraction axiom
        for a0k, a1k, azk in icl:
            if max_iter is not None and citer > max_iter:
                break
            if PsiOpts.is_timer_ended():
                break
            for b0k, b1k, bzk in icl:
                if max_iter is not None and citer > max_iter:
                    break
                if (a0k, a1k, azk) == (b0k, b1k, bzk):
                    continue
                # Each conditioning set must lie within the other triple's variables.
                if azk & ~(b0k | b1k | bzk) != 0:
                    continue
                if bzk & ~(a0k | a1k | azk) != 0:
                    continue
                # Try all four orientations of a and b.
                for aj in range(2):
                    for bj in range(2):
                        citer += 1
                        if max_iter is not None and citer > max_iter:
                            break
                        a0, a1, az = a0k, a1k, azk
                        b0, b1, bz = b0k, b1k, bzk
                        if aj != 0:
                            a0, a1 = a1, a0
                        a0o, a1o, azo = a0, a1, az
                        if bj != 0:
                            b0, b1 = b1, b0
                        b0o, b1o, bzo = b0, b1, bz
                        if a0 & b0 == 0:
                            continue
                        # Weak union on b, then on a, then restrict the sides.
                        b0x = b0 & (az & ~bz)
                        b0 &= ~b0x
                        bz |= b0x
                        b1z = b1 | bz
                        a0x = a0 & (b1z & ~az)
                        a0 &= ~a0x
                        az |= a0x
                        a1x = a1 & (b1z & ~az)
                        a1 &= ~a1x
                        az |= a1x
                        b1 &= az
                        a0 &= b0 & ~bz
                        a1 &= ~bz
                        b1 &= ~bz
                        # if verbose: # REMOVE
                        #     print(str(Expr.Ic(index.from_mask(a0o), index.from_mask(a1o), index.from_mask(azo)))
                        #         + " - " + str(Expr.Ic(index.from_mask(b0o), index.from_mask(b1o), index.from_mask(bzo)))
                        #         + " : " + str(Expr.Ic(index.from_mask(a0), index.from_mask(a1 | b1), index.from_mask(bz)))
                        #         + " / " + str(index.from_mask(az)) + "=" + str(index.from_mask(b1 | bz))
                        #         )
                        # Contraction: I(X; W | Y, Z) and I(X; Y | Z) give I(X; W, Y | Z).
                        if az != b1 | bz:
                            continue
                        if a0 == 0 or a1 | b1 == 0:
                            continue
                        t = clean((a0, a1 | b1, bz))
                        if mask_impl((a0o, a1o, azo), t) or mask_impl((b0o, b1o, bzo), t):
                            continue
                        if verbose:
                            print(str(citer) + ": " + str(Expr.Ic(index.from_mask(a0o), index.from_mask(a1o), index.from_mask(azo)))
                                + " & " + str(Expr.Ic(index.from_mask(b0o), index.from_mask(b1o), index.from_mask(bzo)))
                                + " -> " + str(Expr.Ic(index.from_mask(a0), index.from_mask(a1 | b1), index.from_mask(bz)))
                                )
                        # if t[0] > t[1]:
                        #     t = (t[1], t[0], t[2])
                        if t in iclw:
                            continue
                        #print(str(Expr.Ic(index.from_mask(t[0]), index.from_mask(t[1]), index.from_mask(t[2]))))
                        icl2.add(t)
                        iclw.add(t)
                        did = True
        # Prune triples implied by another triple (in either orientation).
        icl.clear()
        for a0, a1, az in icl2:
            for b0, b1, bz in icl2:
                if (a0, a1, az) == (b0, b1, bz):
                    continue
                if (mask_impl((b0, b1, bz), (a0, a1, az))
                    or mask_impl((b0, b1, bz), (a1, a0, az))):
                    break
            else:
                icl.add((a0, a1, az))
    if vs is not None:
        # Re-express the triples over the entries of vs (bit i = vs[i]).
        vmasks = [index.get_mask(b) for a, b in vs]
        vvars = CompArray([a for a, b in vs])
        icl2 = set()
        for a0, a1, az in icl:
            b0 = 0
            b1 = 0
            bz = 0
            for i in range(len(vs)):
                if a0 | vmasks[i] == a0:
                    b0 |= 1 << i
                if a1 | vmasks[i] == a1:
                    b1 |= 1 << i
            if b0 == 0 or b1 == 0:
                continue
            def recur(i, zmask, azvis):
                # Enumerate sets zmask of vs entries to condition on whose
                # union covers az.
                if i == len(vs):
                    if azvis & az == az:
                        c0 = b0 & ~zmask
                        c1 = b1 & ~zmask
                        if c0 != 0 and c1 != 0:
                            if c0 > c1:
                                c0, c1 = c1, c0
                            icl2.add((c0, c1, zmask))
                    return
                vm = vmasks[i]
                recur(i + 1, zmask, azvis)
                if vm & az != 0 and (a0 | a1 | az) & vm == vm:
                    recur(i + 1, zmask + (1 << i), azvis | vm)
            recur(0, 0, 0)
        icl.clear()
        for a0, a1, az in icl2:
            for b0, b1, bz in icl2:
                if (a0, a1, az) == (b0, b1, bz):
                    continue
                if (mask_impl((b0, b1, bz), (a0, a1, az))
                    or mask_impl((b0, b1, bz), (a1, a0, az))):
                    break
            else:
                icl.add((a0, a1, az))
        r = Expr.zero()
        for a0, a1, az in icl:
            r += Expr.Ic(vvars.from_mask(a0), vvars.from_mask(a1), vvars.from_mask(az))
        # NOTE(review): mz / m0 in fcns are bitmasks over `index`, not over
        # the vs entries, yet they are decoded with vvars.from_mask here.
        # This looks like a coordinate mismatch; confirm before relying on
        # the Hc terms of the vs result.
        for mz, m0 in fcns:
            r += Expr.Hc(vvars.from_mask(m0 & ~mz), vvars.from_mask(mz))
        return r
    r = Expr.zero()
    for a0, a1, az in icl:
        r += Expr.Ic(index.from_mask(a0), index.from_mask(a1), index.from_mask(az))
    for mz, m0 in fcns:
        r += Expr.Hc(index.from_mask(m0 & ~mz), index.from_mask(mz))
    return r
def completed_semigraphoid(self, vs = None, max_iter = None):
    """ Use semi-graphoid axioms to deduce more conditional independence.
    Judea Pearl and Azaria Paz, "Graphoids: a graph-based logic for reasoning
    about relevance relations", Advances in Artificial Intelligence (1987), pp. 357--363.

    Returns the region "completed_semigraphoid_ic(vs, max_iter) <= 0".
    """
    closure = self.completed_semigraphoid_ic(vs = vs, max_iter = max_iter)
    return closure <= 0
def eliminated_ic(self, w):
    """Return "(terms of get_ic() with w eliminated) <= 0".

    A term is dropped when term.z.ispresent(w) holds. Otherwise w is
    subtracted from each of its x components, and the reduced term is kept
    unless it becomes zero.
    """
    icexpr = self.get_ic()
    index = IVarIndex()
    icexpr.record_to(index)  # kept from the original; index is not used afterwards
    total = Expr.zero()
    for term, _ in icexpr.terms:
        if term.z.ispresent(w):
            continue
        reduced = term.copy()
        for k in range(len(reduced.x)):
            reduced.x[k] -= w
        if not reduced.iszero():
            total += Expr.fromterm(reduced)
    return total <= 0
def completed_sfrl(self, gap = None, max_iter = None):
    """Add add_sfrl auxiliaries (presumably from the strong functional
    representation lemma) for every ordered pair of distinct RVs. Then
    complete with the semigraphoid closure and eliminate the added
    auxiliaries via eliminated_ic. Returns the resulting region.
    """
    index = IVarIndex()
    self.record_to(index)
    n = len(index.comprv)
    cs = self.copy()
    tmpvar = Comp.empty()
    for i, j in itertools.permutations(range(n), 2):
        tmpvar += cs.add_sfrl(index.comprv[i], index.comprv[j], gap, noaux = False)
    closed = cs.completed_semigraphoid(max_iter = max_iter)
    return closed.eliminated_ic(tmpvar)
def convexified(self, v = None, bnet = None, q = None, inp = False, forall = False):
    """Convexify with respect to random variables v by a time sharing RV q, return result

    The region is conditioned on a fresh RV (named str(q), default "Q",
    made unique by name_avoid). Its relation to the other RVs is:
      - inp: I(Q; rest | inp) = 0, where rest excludes auxiliaries, v, inp
        and oup;
      - bnet None: I(Q; rest) = 0, where rest excludes auxiliaries and v;
      - otherwise: the region of bnet after adding (Q, x) for each x in v.
    With forall=False the result is "exists Q: region & relation"; with
    forall=True it is "forall Q: region | ~relation". Returns None if
    tosimple_safe() cannot convert the region.
    """
    r = self.copy()
    if r.isregtermpresent():
        r = r.flattened(minmax_elim = True)
    r = r.tosimple_safe()
    if r is None:
        return None
    v = Region.get_allcomp(v)
    qname = r.name_avoid("Q" if q is None else str(q))
    q = Comp.rv(qname)
    allcomp = r.allcomprv()
    r.condition(q)
    if inp:
        cmi = Expr.Ic(q, allcomp - r.getauxall() - v - r.inp - r.oup, r.inp) == 0
    elif bnet is None:
        cmi = Expr.I(q, allcomp - r.getauxall() - v) == 0
    else:
        bnet2 = bnet.copy() if isinstance(bnet, BayesNet) else BayesNet(bnet)
        for v2 in v:
            bnet2 += (q, v2)
        cmi = bnet2.get_region()
    if forall:
        r |= ~cmi
        return r.forall(q)
    r &= cmi
    return r.exists(q)
def tounion(self):
    """Return this region packed as a RegionOp of type RegionType.UNION."""
    kind = RegionType.UNION
    return RegionOp.pack_type(self, kind)
def convexified_diag(self, v = None, cross_only = True, skip_simplify = False):
    """Convexify with respect to the real variables in v along the diagonal, return result

    For every pair (v[i], v[j]) with i < j and every pair of union
    components (a0, a1) of self (only distinct components when cross_only):
    in a0, v[i] is replaced by a fresh real u0 and v[j] by
    v[i] + v[j] - u0; in a1, likewise with a fresh u1. Each substitution
    keeps v[i] + v[j] unchanged. The intersection with u0 <= v[i] <= u1 is
    projected (exists u0, u1) and united into the result. When cross_only
    is True, the result also starts from self's components as they are.

    Parameters:
        v: real variables; defaults to allcomprealvar_exprlist().
        cross_only: do not pair a component with itself.
        skip_simplify: return the union without calling simplified().
    """
    if v is None:
        v = self.allcomprealvar_exprlist()
    # Record every existing name so the fresh reals get unused names.
    index = IVarIndex()
    self.record_to(index)
    for x in v:
        x.record_to(index)
    # NOTE: namemap and toelim are filled in but never read below.
    namemap = [dict(), dict()]
    v_new = [[], []]  # v_new[0][i] / v_new[1][i]: fresh copies of v[i]
    toelim = Expr.zero()
    for x in v:
        cname = x.get_name()
        for it in range(2):
            nname = index.name_avoid(cname)
            namemap[it][cname] = nname
            v_new[it].append(Expr.real(nname))
            # Record the new name so the next fresh name differs from it.
            Expr.real(nname).record_to(index)
            toelim += Expr.real(nname)
    r = RegionOp.empty()
    ru = self.tounion()
    if cross_only:
        r = ru.copy()
    for i, j in itertools.combinations(range(len(v)), 2):
        v0j = v[i] + v[j] - v_new[0][i]
        v1j = v[i] + v[j] - v_new[1][i]
        for k0, a0 in enumerate(ru):
            a0t = a0.copy()
            a0t.substitute(v[i], v_new[0][i])
            a0t.substitute(v[j], v0j)
            for k1, a1 in enumerate(ru):
                if cross_only and k0 == k1:
                    continue
                a1t = a1.copy()
                a1t.substitute(v[i], v_new[1][i])
                a1t.substitute(v[j], v1j)
                t = a0t & a1t
                # v[i] must lie between the two component points.
                t &= (v[i] >= v_new[0][i]) & (v[i] <= v_new[1][i])
                #print(str(i) + " , " + str(j))
                r |= t.exists(v_new[0][i]+v_new[1][i])
    if not skip_simplify:
        return r.simplified()
    else:
        return r
def isconvex(self, v = None, bnet = None, inp = False):
    """Check whether region is convex with respect to random variables v and real variables.
    False return value does NOT necessarily mean region is not convex

    The region is convex if its time-sharing convexification
    (convexified) implies it. A convexification of None counts as not
    convex.
    """
    hull = self.convexified(v, bnet, inp = inp)
    return False if hull is None else hull.implies(self)
def clean_eps(self):
    """Remove the eps term from every expression in exprs_ge and exprs_gei.

    The eps term is substituted by zero. For an expression whose eps
    coefficient was negative, "-= Expr.eps()" is then applied. Does
    nothing if eps does not appear in the region.
    """
    if not self.ispresent(IVar.eps()):
        return
    for expr in self.exprs_ge + self.exprs_gei:
        coeff = expr.get_coeff(Term.eps())
        expr.substitute(Expr.eps(), Expr.zero())
        if coeff < 0:
            # NOTE(review): this rebinds the loop variable, so it only changes
            # the stored expression if Expr implements an in-place __isub__;
            # confirm.
            expr -= Expr.eps()
def simplify_quick(self, reg = None, zero_group = 0):
    """Simplify a region in place, without linear programming
    Optional argument reg with constraints assumed to be true
    zero_group = 2: group all nonnegative terms as a single inequality

    Steps:
      1. simplify each expression (with reg) and remove eps (clean_eps);
      2. over the same RVs, drop an inequality that get_ratio relates to an
         earlier inequality with positive ratio; turn a pair with negative
         ratio into one equality; drop inequalities related to an equality,
         and duplicate equalities;
      3. drop inequalities that are trivially nonnegative (isnonneg);
      4. gather the terms of all sign-definite constraints into one
         inequality "-(sum of terms) >= 0" placed first in exprs_ge;
      5. unless zero_group == 2, split that again per term: "term == 0" for
         zero_group == 0, "-term >= 0" for zero_group == 1;
      6. simplify the implication part recursively and drop constraints
         for which implies_ineq_quick holds.
    A proof step is recorded when the "proof_enabled" and
    "proof_step_simplify" settings are on. Does nothing unless the
    "simplify_enabled" setting is true. Returns self.
    """
    if not PsiOpts.settings.get("simplify_enabled", False):
        return self
    write_pf_enabled = (PsiOpts.settings.get("proof_enabled", False)
                        and PsiOpts.settings.get("proof_step_simplify", False))
    if write_pf_enabled:
        # Keep a copy to record a before/after proof step.
        prevself = self.copy()
    #self.remove_missing_aux()
    if reg is None:
        reg = Region.universe()
    for x in self.exprs_ge:
        x.simplify_quick(reg = reg)
    for x in self.exprs_eq:
        x.simplify_quick(reg = reg)
    self.clean_eps()
    index = IVarIndex()
    self.record_to(index)
    # RV bitmask of each constraint; only constraints over the same RVs are compared.
    gemask = [index.get_mask(x.allcomprv_shallow()) for x in self.exprs_ge]
    eqmask = [index.get_mask(x.allcomprv_shallow()) for x in self.exprs_eq]
    did = True
    if True:
        # Single pass: did is set below but never used to repeat.
        did = False
        # Inequality vs earlier inequality.
        for i in range(len(self.exprs_ge)):
            if not self.exprs_ge[i].iszero():
                for j in range(i):
                    if not self.exprs_ge[j].iszero() and gemask[i] == gemask[j]:
                        ratio = self.exprs_ge[i].get_ratio(self.exprs_ge[j], skip_simplify = True)
                        if ratio is None:
                            continue
                        if ratio > PsiOpts.settings["eps"]:
                            # Same direction: exprs_ge[i] is redundant.
                            self.exprs_ge[i] = Expr.zero()
                            gemask[i] = 0
                            did = True
                            break
                        elif ratio < -PsiOpts.settings["eps"]:
                            # Opposite directions: the pair becomes one equality.
                            self.exprs_eq.append(self.exprs_ge[i])
                            eqmask.append(gemask[i])
                            self.exprs_ge[i] = Expr.zero()
                            gemask[i] = 0
                            self.exprs_ge[j] = Expr.zero()
                            gemask[j] = 0
                            did = True
                            break
        # Inequality related to an equality: the inequality is redundant.
        for i in range(len(self.exprs_ge)):
            if not self.exprs_ge[i].iszero():
                for j in range(len(self.exprs_eq)):
                    if not self.exprs_eq[j].iszero() and gemask[i] == eqmask[j]:
                        ratio = self.exprs_ge[i].get_ratio(self.exprs_eq[j], skip_simplify = True)
                        if ratio is None:
                            continue
                        self.exprs_ge[i] = Expr.zero()
                        gemask[i] = 0
                        did = True
                        break
        # Equality related to an earlier equality: duplicate.
        for i in range(len(self.exprs_eq)):
            if not self.exprs_eq[i].iszero():
                for j in range(i):
                    if not self.exprs_eq[j].iszero() and eqmask[i] == eqmask[j]:
                        ratio = self.exprs_eq[i].get_ratio(self.exprs_eq[j], skip_simplify = True)
                        if ratio is None:
                            continue
                        self.exprs_eq[i] = Expr.zero()
                        eqmask[i] = 0
                        did = True
                        break
        self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()]
        self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()]
    # Trivially nonnegative inequalities carry no information.
    for i in range(len(self.exprs_ge)):
        if self.exprs_ge[i].isnonneg():
            self.exprs_ge[i] = Expr.zero()
    self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()]
    if True:
        # Gather the terms of every sign-definite constraint into one
        # inequality allzero = -(sum of terms) >= 0.
        allzero = Expr.zero()
        for i in range(len(self.exprs_ge)):
            if self.exprs_ge[i].isnonpos():
                for (a, c) in self.exprs_ge[i].terms:
                    allzero -= Expr.fromterm(a)
                self.exprs_ge[i] = Expr.zero()
        for i in range(len(self.exprs_eq)):
            if self.exprs_eq[i].isnonpos() or self.exprs_eq[i].isnonneg():
                for (a, c) in self.exprs_eq[i].terms:
                    allzero -= Expr.fromterm(a)
                self.exprs_eq[i] = Expr.zero()
        if not allzero.iszero():
            allzero.simplify_quick(reg = reg)
            allzero.sortIc()
            #self.exprs_ge.append(allzero)
            self.exprs_ge.insert(0, allzero)
        if zero_group == 2:
            # Keep the grouped inequality as it is.
            pass
        else:
            # Split sign-definite constraints into one constraint per term.
            for i in range(len(self.exprs_ge)):
                if self.exprs_ge[i].isnonpos():
                    for (a, c) in self.exprs_ge[i].terms:
                        if zero_group == 1:
                            self.exprs_ge.append(-Expr.fromterm(a))
                        else:
                            self.exprs_eq.append(Expr.fromterm(a))
                    self.exprs_ge[i] = Expr.zero()
            for i in range(len(self.exprs_eq)):
                if self.exprs_eq[i].isnonpos() or self.exprs_eq[i].isnonneg():
                    for (a, c) in self.exprs_eq[i].terms:
                        if zero_group == 1:
                            self.exprs_ge.append(-Expr.fromterm(a))
                        else:
                            self.exprs_eq.append(Expr.fromterm(a))
                    self.exprs_eq[i] = Expr.zero()
    self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()]
    self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()]
    for x in self.exprs_ge:
        x.simplify_mul(1)
    for x in self.exprs_eq:
        x.simplify_mul(2)
    if self.imp_present():
        # Simplify the implication part the same way.
        t = self.imp_flippedonly()
        t.simplify_quick(reg, zero_group)
        self.exprs_gei = t.exprs_ge
        self.exprs_eqi = t.exprs_eq
    for x in self.exprs_ge:
        if self.implies_ineq_quick(x, ">="):
            x.setzero()
    for x in self.exprs_eq:
        if self.implies_ineq_quick(x, "=="):
            x.setzero()
    self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()]
    self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()]
    if write_pf_enabled:
        if self.tostring() != prevself.tostring():
            # pf = ProofObj.from_region(prevself, c = "Simplify")
            # pf += ProofObj.from_region(self, c = "Simplified as")
            pf = ProofObj.from_region(("equiv", prevself, self), c = "Simplify:")
            PsiOpts.set_setting(proof_add = pf)
    return self
def iand_simplify_quick(self, other, skip_simplify = True):
    """Intersect other into self in place.

    Constraints for which implies_ineq_cons_hash reports that self already
    implies them are skipped. If skip_simplify is False and anything was
    added, simplify_quick(zero_group = 1) is run afterwards. Returns self.
    """
    added = False
    for x in other.exprs_ge:
        if self.implies_ineq_cons_hash(x, ">="):
            continue
        self.exprs_ge.append(x)
        added = True
    for x in other.exprs_eq:
        if self.implies_ineq_cons_hash(x, "=="):
            continue
        self.exprs_eq.append(x)
        added = True
    if added and not skip_simplify:
        self.simplify_quick(zero_group = 1)
    return self
def split_ic2(self):
    """Pull isic2() terms out of sign-definite constraints.

    For each exprs_ge entry that isnonpos(), and each exprs_eq entry that
    isnonpos() or isnonneg(), every term a with a.isic2() is removed from
    the expression and "-a >= 0" is placed at the front of exprs_ge. The
    other terms stay with their coefficients. Zero expressions are
    dropped afterwards.
    """
    front = []

    def peel(expr):
        rest = Expr.zero()
        for term, coeff in expr.terms:
            if term.isic2():
                front.append(-Expr.fromterm(term))
            else:
                rest += Expr.fromterm(term) * coeff
        return rest

    for i, x in enumerate(self.exprs_ge):
        if x.isnonpos():
            self.exprs_ge[i] = peel(x)
    for i, x in enumerate(self.exprs_eq):
        if x.isnonpos() or x.isnonneg():
            self.exprs_eq[i] = peel(x)
    self.exprs_ge = [x for x in front + self.exprs_ge if not x.iszero()]
    self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()]
def split(self):
    """Split every sign-definite constraint into one constraint per term.

    Each exprs_ge entry that isnonpos(), and each exprs_eq entry that
    isnonpos() or isnonneg(), is replaced by "-a >= 0" for each of its
    terms a; these are placed at the front of exprs_ge. If imp_present(),
    the implication part is split the same way.
    """
    front = []
    for i, x in enumerate(self.exprs_ge):
        if x.isnonpos():
            front.extend(-Expr.fromterm(a) for a, _ in x.terms)
            self.exprs_ge[i] = Expr.zero()
    for i, x in enumerate(self.exprs_eq):
        if x.isnonpos() or x.isnonneg():
            front.extend(-Expr.fromterm(a) for a, _ in x.terms)
            self.exprs_eq[i] = Expr.zero()
    self.exprs_ge = [x for x in front + self.exprs_ge if not x.iszero()]
    self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()]
    if self.imp_present():
        flipped = self.imp_flippedonly()
        flipped.split()
        self.exprs_gei = flipped.exprs_ge
        self.exprs_eqi = flipped.exprs_eq
@staticmethod
def intersection_interleave(rlist):
    """Intersect the regions in rlist, interleaving their constraint lists.

    Each constraint list (and aux / auxi, which are then summed into a
    Comp) is built with iutil.list_interleave over the regions.
    """
    def weave(attr):
        return iutil.list_interleave(getattr(x, attr) for x in rlist)

    r = Region.universe()
    r.exprs_ge = weave("exprs_ge")
    r.exprs_eq = weave("exprs_eq")
    r.exprs_gei = weave("exprs_gei")
    r.exprs_eqi = weave("exprs_eqi")
    r.aux = sum(weave("aux"), Comp.empty())
    r.auxi = sum(weave("auxi"), Comp.empty())
    return r
def symmetrized(self, symm_set, union = False, convexify = False, skip_simplify = False):
    """Return the region symmetrized over permutations of the groups in symm_set.

    symm_set is a sequence of n groups. Only the first m entries of each
    group are used, where m is the length of the shortest group. For every
    permutation p of the groups, the variables of group i are replaced by
    those of group p[i]. The resulting regions are intersected (constraints
    interleaved), or united when union is True. If convexify is True, the
    result goes through convexified_diag. Unless skip_simplify is True, the
    result is simplified.

    If symm_set is None, self itself (not a copy) is returned.
    """
    if symm_set is None:
        return self
    ngroups = len(symm_set)
    width = min(len(group) for group in symm_set)
    base = self.copy()
    # First replace every symmetric variable by a unique placeholder, so the
    # permuted substitutions below cannot interfere with one another.
    placeholders = []
    for gi in range(ngroups):
        row = []
        placeholders.append(row)
        for vj in range(width):
            name = "#TMPVAR" + str(gi) + "_" + str(vj)
            if isinstance(symm_set[gi][vj], Comp):
                row.append(Comp.rv(name))
            else:
                row.append(Expr.real(name))
            base.substitute(symm_set[gi][vj], row[vj])
    variants = []
    for perm in itertools.permutations(range(ngroups)):
        variant = base.copy()
        for gi in range(ngroups):
            for vj in range(width):
                variant.substitute(placeholders[gi][vj], symm_set[perm[gi]][vj])
        variants.append(variant)
    if union:
        result = RegionOp.empty()
        for variant in variants:
            result |= variant
    else:
        result = Region.intersection_interleave(variants)
    if convexify:
        result = result.convexified_diag(skip_simplify = True)
    return result if skip_simplify else result.simplified()
def simplify_bayesnet(self, reg = None, reduce_ic = False):
    """Simplify expressions using the conditional independences of the region.

    All sign-definite constraints are summed into an expression icexpr:
    ">= 0" constraints with a nonpositive expression, and "== 0" constraints
    with a nonpositive or nonnegative expression. A Bayesian network is then
    built from reg & (icexpr >= 0). Every other constraint is simplified in
    place against that network. If reduce_ic is True, sign-definite
    constraints that the network already certifies (check_ic) are dropped.
    Finally, the network's own region is intersected into self.

    Parameters
    ----------
    reg : Region or RegionOp, optional
        Extra constraints assumed to hold (defaults to the universe).
    reduce_ic : bool
        Whether to drop sign-definite constraints implied by the network.

    Returns
    -------
    self, also on the early exit when there is nothing to do (that path
    used to return None, unlike every other path and sibling method).
    """
    if isinstance(reg, RegionOp):
        reg = reg.tosimple_noaux()
    if reg is None:
        reg = Region.universe()
    icexpr = Expr.zero()
    for x in self.exprs_ge:
        if x.isnonpos():
            icexpr += x
    for x in self.exprs_eq:
        if x.isnonpos():
            icexpr += x
        elif x.isnonneg():
            icexpr -= x
    if reg.isuniverse() and icexpr.iszero():
        return self
    bnet = (reg & (icexpr >= 0)).get_bayesnet(skip_simplify = True)
    for x in self.exprs_ge:
        # Skip the constraints the network was built from (prevents circular simplification)
        if not x.isnonpos():
            x.simplify_quick(bnet = bnet)
    for x in self.exprs_eq:
        # Skip the constraints the network was built from (prevents circular simplification)
        if not (x.isnonpos() or x.isnonneg()):
            x.simplify_quick(bnet = bnet)
    if reduce_ic:
        for x in self.exprs_ge:
            if x.isnonpos() and bnet.check_ic(-x):
                x.setzero()
        for x in self.exprs_eq:
            if x.isnonpos():
                if bnet.check_ic(-x):
                    x.setzero()
            elif x.isnonneg():
                if bnet.check_ic(x):
                    x.setzero()
        self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()]
        self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()]
    self.iand_norename(bnet.get_region())
    return self
def zero_exprs(self, avoid = None):
    """Yield expressions that this region forces to be zero.

    For ">= 0" constraints with a nonpositive expression, and "== 0"
    constraints with a nonpositive or nonnegative expression, every item
    obtained by iterating the expression is yielded. Any other "== 0"
    constraint is yielded whole. The constraint object `avoid` (compared by
    identity) is skipped.
    """
    for expr in self.exprs_ge:
        if expr is not avoid and expr.isnonpos():
            yield from expr
    for expr in self.exprs_eq:
        if expr is avoid:
            continue
        if expr.isnonpos() or expr.isnonneg():
            yield from expr
        else:
            yield expr
def empty_rvs(self):
    """Collect variables appearing in single-term zero expressions.

    For every expression from zero_exprs() that has exactly one term with a
    nonzero coefficient, where that term satisfies ish(), the component
    term.x[0] is added to the returned Comp.
    """
    found = Comp.empty()
    for expr in self.zero_exprs():
        if len(expr) != 1:
            continue
        term, coeff = expr.terms[0]
        if coeff != 0 and term.ish():
            found += term.x[0]
    return found
def simplify_pair(self, reg = None):
    """Lower constraint complexity by adding known-zero expressions, in place.

    For each constraint of exprs_ge / exprs_eq, candidates generated by
    igen.pm from the zero expressions of self (excluding that constraint)
    and of reg are added to it. Whenever the quickly simplified sum has
    lower complexity(), the constraint is overwritten via copy_.

    Variables that reg forces empty are then substituted by the empty Comp.
    Variables that self forces empty are substituted too, and an
    H(variable) == 0 constraint is re-added for each of them. If anything
    changed, simplify_quick() is run. Returns self.
    """
    if isinstance(reg, RegionOp):
        reg = reg.tosimple_noaux()
    if reg is None:
        reg = Region.universe()
    changed = False
    for expr in self.exprs_ge + self.exprs_eq:
        best = expr.complexity()
        candidates = itertools.chain(self.zero_exprs(avoid = expr), reg.zero_exprs())
        for zero in igen.pm(candidates):
            trial = (expr + zero).simplified_quick()
            trial_comp = trial.complexity()
            if trial_comp < best:
                changed = True
                expr.copy_(trial)
                best = trial_comp
    for rv in reg.empty_rvs():
        changed = True
        self.substitute(rv, Comp.empty())
    for rv in self.empty_rvs():
        changed = True
        self.substitute(rv, Comp.empty())
        self.exprs_eq.append(Expr.H(rv))
    if changed:
        self.simplify_quick()
    return self
def simplify_redundant(self, reg = None, proc = None, full = True, quick = False, bnet = None):
    """Remove constraints implied by the other constraints, in place.

    Each of exprs_ge, exprs_eq, exprs_gei and exprs_eqi is scanned from its
    last entry to its first. An entry is tentatively removed. If the
    remaining constraints intersected with reg imply it (check_plain), it
    stays removed; otherwise it is restored. Entries of exprs_ge / exprs_eq
    are checked against imp_intersection_noaux(). Entries of
    exprs_gei / exprs_eqi are checked against imp_flippedonly_noaux().
    Scanning stops early when the PsiOpts timer has ended.

    Parameters
    ----------
    reg : Region, optional
        Constraints assumed to hold (defaults to the universe).
    proc : callable, optional
        Applied to every premise region before the implication check
        (e.g. symmetrization in simplify_symm).
    full : bool
        If False, the premise is first restricted with remove_notcontained
        to the variables of the tested constraint plus all real variables.
    quick, bnet
        Forwarded to check_plain.

    When both "proof_enabled" and "proof_step_simplify" are set, proof
    output is suspended during the checks. Afterwards, if anything was
    removed, the removed constraints and the resulting region are added to
    the proof. The suspended setting is restored even if an exception is
    raised.

    Returns
    -------
    self
    """
    write_pf_enabled = (PsiOpts.settings.get("proof_enabled", False)
        and PsiOpts.settings.get("proof_step_simplify", False))
    aux_relax = PsiOpts.settings.get("simplify_aux_relax", False)
    red_reg = None
    prev_write_pf_enabled = None
    if write_pf_enabled:
        red_reg = Region.universe()
        prev_write_pf_enabled = PsiOpts.settings.get("proof_enabled", False)
        PsiOpts.settings["proof_enabled"] = False
    try:
        if reg is None:
            reg = Region.universe()
        allcompreal = self.allcompreal() + reg.allcompreal()
        aux = self.aux

        def preprocess(r):
            # Optional user transform, then optional strengthening w.r.t. self.aux.
            if proc is not None:
                r = proc(r)
            if aux_relax:
                r.aux += aux
                r.aux_strengthen()
                r.aux = Comp.empty()
            return r

        def remove_redundant(exprs, get_premise, is_eq, removed):
            # Scan exprs (mutated in place, so get_premise() sees the removal)
            # from the back; return the surviving non-zero entries.
            for i in range(len(exprs) - 1, -1, -1):
                if PsiOpts.is_timer_ended():
                    break
                t = exprs[i]
                exprs[i] = Expr.zero()
                cs = get_premise() & reg
                if not full:
                    cs.remove_notcontained(t.allcomp() + allcompreal)
                cs = preprocess(cs)
                concl = (t == 0) if is_eq else (t >= 0)
                if not (cs <= concl).check_plain(quick = quick, bnet = bnet):
                    exprs[i] = t
                elif removed is not None:
                    removed.append(t)
            return [x for x in exprs if not x.iszero()]

        self.exprs_ge = remove_redundant(self.exprs_ge, self.imp_intersection_noaux, False,
            red_reg.exprs_ge if write_pf_enabled else None)
        self.exprs_eq = remove_redundant(self.exprs_eq, self.imp_intersection_noaux, True,
            red_reg.exprs_eq if write_pf_enabled else None)
        self.exprs_gei = remove_redundant(self.exprs_gei, self.imp_flippedonly_noaux, False,
            red_reg.exprs_gei if write_pf_enabled else None)
        self.exprs_eqi = remove_redundant(self.exprs_eqi, self.imp_flippedonly_noaux, True,
            red_reg.exprs_eqi if write_pf_enabled else None)

        if write_pf_enabled and not red_reg.isuniverse():
            # Log the removed constraints first, then the resulting region.
            PsiOpts.set_setting(proof_add = ProofObj.from_region(
                red_reg, c = "Remove redundant constraints"))
            PsiOpts.set_setting(proof_add = ProofObj.from_region(self, c = "Result"))
    finally:
        if write_pf_enabled:
            PsiOpts.settings["proof_enabled"] = prev_write_pf_enabled
    return self
def simplify_symm(self, symm_set, quick = False):
    """Simplify a region in place, assuming symmetry among variables in symm_set.

    If symm_set is a Comp, the region is first reordered with
    symm_sort(symm_set). A quick simplification always runs. Unless quick
    is True, redundant constraints are then removed. Each implication check
    is done against the symmetrized premise
    (symmetrized(symm_set, skip_simplify = True)).

    Returns
    -------
    self, consistent with the other in-place simplify_* methods (this
    previously returned None).
    """
    if isinstance(symm_set, Comp):
        self.symm_sort(symm_set)
    self.simplify_quick()
    if not quick:
        self.simplify_redundant(proc = lambda t: t.symmetrized(symm_set, skip_simplify = True))
    return self
def simplified_symm(self, symm_set, quick = False):
    """Return a simplified copy, assuming symmetry among variables in symm_set.

    The copy is simplified with simplify_symm(symm_set, quick); self is left
    unchanged.
    """
    result = self.copy()
    result.simplify_symm(symm_set, quick)
    return result
def var_mi_only(self, v):
    """Return True iff every constraint expression satisfies var_mi_only(v).

    Covers exprs_ge, exprs_eq, exprs_gei and exprs_eqi. Presumably this means
    v appears only inside mutual-information terms -- TODO confirm against
    Expr.var_mi_only.
    """
    for group in (self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi):
        for expr in group:
            if not expr.var_mi_only(v):
                return False
    return True
def aux_eq_towards(self, v, tgt, reg = None, quick = False, alt_cases = None):
    """Try to turn an inequality involving auxiliary v into an equality.

    The comparison is between each constraint containing v and the same
    constraint with v replaced by tgt. Implications are checked under
    sreg = (self without aux) & reg, using its Bayesian network.

    Returns False (self unchanged) if:
      - the region has an implication part;
      - v is not var_mi_only;
      - some equality containing v fails to hold after the substitution;
      - no inequality fails after the substitution; or
      - (alt_cases is None) the candidate inequalities cannot be ordered,
        or none of them can be shown to be <= 0 after substitution.

    With alt_cases None, the smallest violated inequality (as ordered by
    sreg) is moved from exprs_ge to exprs_eq in place, and True is returned.
    Otherwise self is not modified: for each violated inequality, a copy of
    self with that inequality turned into an equality is appended to
    alt_cases. If none of them is shown <= 0, a copy with v substituted by
    tgt is also appended. True is returned.
    """
    if self.imp_present():
        return False
    if not self.var_mi_only(v):
        return False
    if isinstance(reg, RegionOp):
        reg = reg.tosimple()
    if reg is None:
        reg = Region.universe()
    sreg = self.copy_noaux() & reg
    sreg_bnet = sreg.get_bayesnet()
    # Constraints mentioning v, before (e0) and after (e1) substituting v -> tgt.
    ege = [a for a in self.exprs_ge if a.ispresent(v)]
    eeq = [a for a in self.exprs_eq if a.ispresent(v)]
    eget = [a.substituted(v, tgt) for a in ege]
    eeqt = [a.substituted(v, tgt) for a in eeq]
    for e0, e1 in zip(eeq, eeqt):
        with PsiOpts(proof_enabled = False):
            if not sreg.implies(e1 == 0, quick = quick, bnet = sreg_bnet):
                return False
    # Inequalities that no longer hold after the substitution (oe0s) and their
    # substituted forms (oe1s). With alt_cases None, at most one is kept.
    oe0s = []
    oe1s = []
    for i, (e0, e1) in enumerate(zip(ege, eget)):
        with PsiOpts(proof_enabled = False):
            if sreg.implies(e1 >= 0, quick = quick, bnet = sreg_bnet):
                continue
            if alt_cases is None and oe0s:
                # Keep only the smaller of e0 and the current candidate;
                # give up if sreg cannot compare them.
                if sreg.implies(e0 <= oe0s[0], quick = quick, bnet = sreg_bnet):
                    oe0s.pop()
                    oe1s.pop()
                elif sreg.implies(e0 >= oe0s[0], quick = quick, bnet = sreg_bnet):
                    continue
                else:
                    return False
            oe0s.append(e0)
            oe1s.append(e1)
    if not oe0s:
        return False
    lowany = False
    with PsiOpts(proof_enabled = False):
        lowany = any(sreg.implies(oe1 <= 0, quick = quick, bnet = sreg_bnet) for oe1 in oe1s)
    if alt_cases is None:
        if not lowany:
            return False
        # NOTE(review): list.remove compares with ==, which Expr may overload;
        # confirm this removes exactly oe0s[0].
        self.exprs_ge.remove(oe0s[0])
        self.exprs_eq.append(oe0s[0])
        return True
    else:
        for oe0, oe1 in zip(oe0s, oe1s):
            t = self.copy()
            # Use self's own Expr objects so oe0 can be found and removed;
            # the final copy() below gives the case its own objects.
            t.exprs_ge = list(self.exprs_ge)
            t.exprs_ge.remove(oe0)
            t.exprs_eq.append(oe0)
            t = t.copy()
            alt_cases.append(t)
        if not lowany:
            t = self.copy()
            t.substitute_aux(v, tgt)
            alt_cases.append(t)
        return True
def simplify_aux_eq(self, reg = None, quick = False, full = False):
    """Repeatedly apply aux_eq_towards to the auxiliaries, in place.

    Each sweep (at most 100) visits the auxiliaries, alternating reversed
    and forward order. For each auxiliary it tries the target Comp.empty().
    If full is True, it instead tries every subset of the auxiliary's
    neighbours (var_neighbors). A sweep stops at the first success. The
    loop ends when a sweep makes no change or the PsiOpts timer has ended.

    If the region has an implication part, the constraint-only part
    (consonly) is simplified instead, with the same reg/quick/full, and
    shared back. Previously `full` was dropped on that path.

    Returns
    -------
    self
    """
    if self.aux.isempty():
        return self
    if self.imp_present():
        cs = self.consonly()
        cs.simplify_aux_eq(reg, quick, full)
        self.cons_shallow_copy_(cs)
        return self
    if isinstance(reg, RegionOp):
        reg = reg.tosimple()
    if reg is None:
        reg = Region.universe()
    for it in range(100):
        if PsiOpts.is_timer_ended():
            break
        did = False
        for a in (self.aux if it % 2 else reversed(self.aux)):
            if PsiOpts.is_timer_ended():
                break
            ane = self.var_neighbors(a)
            totry = igen.subset(ane) if full else [Comp.empty()]
            for tgt in totry:
                if PsiOpts.is_timer_ended():
                    break
                if self.aux_eq_towards(a, tgt, reg = reg, quick = quick):
                    did = True
                    break
            if did:
                break
        if not did:
            break
    return self
def aux_eq_cases_inner(self, reg = None, quick = False, full = True, target = None):
    """Return the list of case regions produced by aux_eq_towards.

    If there are no auxiliaries, or there is an implication part, the result
    is [copy of self]. Otherwise, for each auxiliary, candidate targets are
    tried: target restricted to the auxiliary's neighbours if target is
    given, else every neighbour subset if full, else Comp.empty(). On the
    first success, every produced case is expanded recursively and the
    concatenated results are returned. If nothing succeeds (or the timer
    ends), [copy of self] is returned.
    """
    if self.aux.isempty() or self.imp_present():
        return [self.copy()]
    if isinstance(reg, RegionOp):
        reg = reg.tosimple()
    if reg is None:
        reg = Region.universe()
    alt_cases = []
    for a in self.aux:
        if PsiOpts.is_timer_ended():
            break
        neighbors = self.var_neighbors(a)
        if target is not None:
            candidates = [target.inter(neighbors)]
        elif full:
            candidates = igen.subset(neighbors)
        else:
            candidates = [Comp.empty()]
        for tgt in candidates:
            if PsiOpts.is_timer_ended():
                break
            if self.aux_eq_towards(a, tgt, reg = reg, quick = quick, alt_cases = alt_cases):
                expanded = []
                for case in alt_cases:
                    expanded += case.aux_eq_cases_inner(reg = reg, quick = quick, full = full, target = target)
                return expanded
    return [self.copy()]
def aux_eq_cases(self, reg = None, quick = False, full = True, target = None, skip_simplify = False):
    """Combine the cases from aux_eq_cases_inner into one region.

    A single case is returned as is. Several cases are merged as
    anyor(case.noaux() ...) with the original auxiliaries existentially
    quantified, then simplified unless skip_simplify is True.
    """
    cases = self.aux_eq_cases_inner(reg = reg, quick = quick, full = full, target = target)
    if len(cases) != 1:
        merged = anyor(case.noaux() for case in cases).exists(self.aux)
        return merged if skip_simplify else merged.simplified()
    return cases[0]
def aux_set(self, aux, skip_simplify = False, **kwargs):
    """
    Obtain a region with the prescribed choices of auxiliaries. May enlarge the region.
    Parameters
    ----------
    aux : list of Comp
        The choices of the new auxiliaries. An entry may also be a tuple
        (A, value) naming the new auxiliary A explicitly. For any other
        entry, a new auxiliary is created from the name sequence A1, A2, ...
        If that name is already used in this region or by an explicitly
        named auxiliary, it is changed with name_avoid.
    skip_simplify : bool
        If True, return the result of discover as is, without projecting
        out the new auxiliaries and without simplifying.
    **kwargs
        Forwarded to discover.
    Returns
    -------
    Region
        The new region.
    """
    # Record all names in use so a generated auxiliary cannot clash with an
    # existing variable (the same approach aux_reduced uses).
    index = IVarIndex()
    self.record_to(index)
    for a in aux:
        if isinstance(a, tuple):
            a[0].record_to(index)
    new_aux_sum = Comp.empty()
    aux2 = []
    t_new_aux = rv_seq("A", 1, len(aux) + 1)
    for i, a in enumerate(aux):
        if isinstance(a, tuple):
            aux2.append(a)
            new_aux_sum += a[0]
        else:
            c = Comp.rv(index.name_avoid(t_new_aux[i].get_name()))
            index.record(c)
            aux2.append((c, a))
            new_aux_sum += c
    discover_list = list(self.allcomprv() - self.aux) + list(self.allcomprealvar_exprlist())
    t = self.discover(discover_list + aux2, **kwargs)
    if not skip_simplify:
        t = t.exists(new_aux_sum.inter(t.allcomprv())).simplified()
    return t
def aux_reduced(self, new_aux = None, maxsize = 3, skip_simplify = False, aux_pairs = None, aux_force = None, score_fcn = None, **kwargs):
    """
    Reduce the number of auxiliaries. May enlarge the region.
    Parameters
    ----------
    new_aux : Comp or int
        The new auxiliaries. len(new_aux) will be the number of auxiliaries in the new region.
        An int n creates n new auxiliaries named after "A", avoiding names already in use.
    maxsize : int
        The max number of old auxiliaries each new auxiliary corresponds to.
    skip_simplify : bool
        If True, each discovered candidate is kept without projection/simplification.
    aux_pairs : list of pairs, optional
        A candidate is skipped when the weighted count of pair first elements it
        contains is smaller than that of pair second elements.
    aux_force : optional
        If given (and n > 0), only candidates in which some new auxiliary
        contains aux_force are tried.
    score_fcn : callable, optional
        Maps a candidate (tuple of groups of old auxiliaries) to a score, or
        None to discard it. Candidates are tried in increasing score order;
        the first whose score is a tuple with negative second entry is tried first.
    **kwargs
        Forwarded to discover.
    Returns
    -------
    Region
        The new region.
    """
    verbose = PsiOpts.settings.get("verbose_aux_reduced", False)
    if self.imp_present():
        return self.copy()
    if new_aux is None:
        new_aux = []
    if isinstance(new_aux, int):
        if new_aux == 0:
            new_aux = []
        elif new_aux == 1:
            new_aux = [Comp.rv(self.name_avoid("A"))]
        else:
            t_new_aux = rv_seq("A", 1, new_aux + 1)
            new_aux = []
            index = IVarIndex()
            self.record_to(index)
            for a in t_new_aux:
                cname = index.name_avoid(a.get_name())
                c = Comp.rv(cname)
                index.record(c)
                new_aux.append(c)
    n = len(new_aux)
    new_aux_sum = sum(new_aux, Comp.empty())
    discover_list = list(self.allcomprv() - self.aux) + list(self.allcomprealvar_exprlist())
    selfnoaux = self.noaux()
    index_self = IVarIndex()
    selfnoaux.record_to(index_self)
    progs = []
    r = None
    with PsiOpts(discover_max_facet = None):
        iters = itertools.combinations(list(igen.subset(self.aux, minsize = 1, maxsize = maxsize)), n)
        if score_fcn is not None:
            tlist = []
            for t in iters:
                score = score_fcn(t)
                if score is None:
                    continue
                tlist.append((score, t))
            # Sort on the score alone (stable). A plain tuple sort would, on
            # equal scores, go on to compare the candidate tuples of Comp
            # objects, which may not support ordering.
            tlist.sort(key = lambda st: st[0])
            iters = []
            for i, (score, t) in enumerate(tlist):
                if isinstance(score, tuple) and score[1] < 0:
                    iters.append(t)
                    tlist.pop(i)
                    break
            iters += [t for score, t in tlist]
        for cauxs in iters:
            if PsiOpts.is_timer_ended():
                break
            if aux_force is not None:
                if n > 0 and not any(c.ispresent(aux_force) for c in cauxs):
                    continue
            if aux_pairs is not None:
                scores = [0, 0]
                for c in cauxs:
                    for i, p in enumerate(aux_pairs):
                        for it in range(2):
                            if c.ispresent(p[it]):
                                scores[it] += 100 + i
                if scores[0] < scores[1]:
                    if verbose:
                        print(", ".join(str(c) for c in cauxs) + " pair skipped")
                    continue
            if r is not None:
                # The previous result, with the new auxiliaries mapped to this
                # candidate in some order, must be implied by self without aux.
                bad = True
                for cor in itertools.permutations(range(n)):
                    r2 = r.copy()
                    r2.remove_notpresent(new_aux_sum)
                    for i in range(n):
                        r2.substitute_aux(new_aux[i], cauxs[cor[i]])
                    if selfnoaux.implies_saved(r2.noaux(), index_self, progs):
                        cauxs = tuple([cauxs[cor[i]] for i in range(n)])
                        bad = False
                        break
                if bad:
                    if verbose:
                        print(", ".join(str(c) for c in cauxs) + " subset fail")
                    continue
            if PsiOpts.is_timer_ended():
                break
            if verbose:
                print(", ".join(str(c) for c in cauxs) + " discover...")
            t = self.discover(discover_list + list(zip(new_aux, cauxs)), **kwargs)
            if PsiOpts.is_timer_ended():
                break
            if verbose:
                print(t)
                print()
            if not skip_simplify:
                t = t.exists(new_aux_sum.inter(t.allcomprv())).simplified().noaux()
                if verbose:
                    print("Simplified:")
                    print(t)
                    print()
            r = t
    if r is None:
        return self.copy()
    r = r.exists(new_aux_sum.inter(r.allcomprv()))
    return r
def aux_hull_towards(self, v, tgt, reg = None, quick = False):
    """Try to relax the inequalities containing auxiliary v towards v -> tgt.

    Each inequality e0 containing v is compared (under reg) with e1, the
    same inequality with v replaced by tgt. An index i is usable when
    exactly one of e0 <= e1 or e0 >= e1 is implied; the gap inc between them
    is then computed. Each other comparable inequality gets a level lvls in
    {-1, 0, 1}, based on whether its own gap dominates inc. If all levels
    can be assigned and at least one is +1, a real variable t with
    0 <= t <= inc is introduced, lvls * t is added to the inequalities, and
    t is eliminated. True is returned on the first such change.

    Returns False if the region has an implication part, if v is not
    var_mi_only, if some equality containing v changes under the
    substitution, or if no index is usable.
    """
    # This is needed for e.g. Gray-Wyner network
    if self.imp_present():
        return False
    if not self.var_mi_only(v):
        return False
    if isinstance(reg, RegionOp):
        reg = reg.tosimple()
    if reg is None:
        reg = Region.universe()
    # Constraints mentioning v, before and after substituting v -> tgt.
    ege = [a for a in self.exprs_ge if a.ispresent(v)]
    eeq = [a for a in self.exprs_eq if a.ispresent(v)]
    eget = [a.substituted(v, tgt) for a in ege]
    eeqt = [a.substituted(v, tgt) for a in eeq]
    for e0, e1 in zip(eeq, eeqt):
        with PsiOpts(proof_enabled = False):
            if not reg.implies(e0 == e1, quick = quick):
                return False
    # sns[0][i]: reg implies ege[i] <= eget[i]; sns[1][i]: reg implies ege[i] >= eget[i].
    sns = [[False] * len(ege), [False] * len(ege)]
    for i, (e0, e1) in enumerate(zip(ege, eget)):
        with PsiOpts(proof_enabled = False):
            sns[0][i] = reg.implies(e0 <= e1, quick = quick)
            sns[1][i] = reg.implies(e0 >= e1, quick = quick)
    for i, (e0, e1, s0, s1) in enumerate(zip(ege, eget, sns[0], sns[1])):
        # Only indices where exactly one direction is implied are usable.
        if s0 and s1:
            continue
        if not (s0 or s1):
            continue
        did = False
        lvls = [0] * len(ege)
        inc = Expr.zero()
        if s0:
            inc = e1 - e0
            lvls[i] = 1
            did = True
        else:
            inc = e0 - e1
            lvls[i] = -1
        inc.simplify_quick()
        bad = False
        for iz, (ez0, ez1, sz0, sz1) in enumerate(zip(ege, eget, sns[0], sns[1])):
            if iz == i:
                continue
            if sz0 and sz1:
                continue
            ezdiff = (ez1 - ez0).simplified_quick()
            if sz0:
                with PsiOpts(proof_enabled = False):
                    if reg.implies(ezdiff >= inc, quick = quick):
                        lvls[iz] = 1
                        did = True
                    else:
                        lvls[iz] = 0
            else:
                with PsiOpts(proof_enabled = False):
                    if reg.implies(ezdiff >= -inc, quick = quick):
                        lvls[iz] = -1
                    else:
                        bad = True
                        break
        if bad or not did:
            continue
        t = Expr.real("#TMPVAR")
        for iz, (ez0, ez1, sz0, sz1) in enumerate(zip(ege, eget, sns[0], sns[1])):
            # NOTE(review): this only changes the constraint if Expr.__iadd__
            # modifies ez0 in place; confirm, otherwise it is a no-op rebinding.
            ez0 += lvls[iz] * t
        self.iand_norename(t >= 0)
        self.iand_norename(t <= inc)
        with PsiOpts(simplify_aux_hull = False, simplify_aux_hull_lower_complexity = False, proof_enabled = False):
            self.eliminate(t)
        return True
    return False
def cons_shallow_copy_(self, other):
    """Make this region share other's exprs_ge, exprs_eq and aux, by reference.

    Nothing is copied; exprs_gei, exprs_eqi and auxi are left untouched.
    Returns self.
    """
    for attr in ("exprs_ge", "exprs_eq", "aux"):
        setattr(self, attr, getattr(other, attr))
    return self
def simplify_aux_hull(self, reg = None, quick = False, lower_complexity = False):
    """Repeatedly apply aux_hull_towards to the auxiliaries, in place.

    Each sweep (at most 100) visits the auxiliaries, alternating reversed
    and forward order. It tries every subset of each auxiliary's neighbours
    as target and stops at the first success. The loop ends when a sweep
    makes no change or the PsiOpts timer has ended.

    If lower_complexity is True, the result is quickly simplified and
    discarded (self restored) unless its complexity() is lower than before.
    If there is an implication part, only the constraint part (consonly) is
    processed and shared back. Returns self.
    """
    if self.aux.isempty():
        return self
    if self.imp_present():
        cons = self.consonly()
        cons.simplify_aux_hull(reg, quick, lower_complexity)
        self.cons_shallow_copy_(cons)
        return self
    if isinstance(reg, RegionOp):
        reg = reg.tosimple()
    if reg is None:
        reg = Region.universe()
    backup = self.copy() if lower_complexity else None
    for sweep in range(100):
        if PsiOpts.is_timer_ended():
            break
        progressed = False
        order = self.aux if sweep % 2 else reversed(self.aux)
        for a in order:
            if PsiOpts.is_timer_ended():
                break
            for tgt in igen.subset(self.var_neighbors(a)):
                if self.aux_hull_towards(a, tgt, reg = reg, quick = quick):
                    progressed = True
                    break
            if progressed:
                break
        if not progressed:
            break
    if lower_complexity:
        self.simplify_quick()
        if self.complexity() >= backup.complexity():
            self.copy_(backup)
    return self
def simplify_aux_combine(self, reg = None, r_did = None):
    """Merge or drop auxiliaries when the region allows it, in place.

    Works on selfplain: self with its auxiliaries treated as ordinary
    variables, intersected with reg. Each round tries two phases:
      1. For a pair aux[i], aux[j] (i < j): if selfplain implies the region
         with both replaced by aux[i] + aux[j], then aux[j] is substituted
         by aux[i].
      2. Otherwise, for an ordered pair (i, j): if selfplain implies the
         region with aux[j] emptied and aux[i] replaced by aux[i] + aux[j],
         then aux[j] is substituted by the empty Comp.
    Rounds repeat until nothing changes, no auxiliary remains, or the
    PsiOpts timer has ended. r_did[0] is set to True on any change. If
    there is an implication part, only consonly() is processed and shared
    back. Returns self.
    """
    if self.aux.isempty():
        return self
    if self.imp_present():
        cs = self.consonly()
        cs.simplify_aux_combine(reg, r_did)
        self.cons_shallow_copy_(cs)
        return self
    if reg is None:
        reg = Region.universe()
    did = True
    while did:
        if self.aux.isempty():
            break
        if PsiOpts.is_timer_ended():
            break
        selfplain = self.copy()
        selfplain.aux = Comp.empty()
        selfplain &= reg
        did = False
        # Phase 1: identify two auxiliaries with their joint variable.
        for i in range(len(self.aux)):
            if PsiOpts.is_timer_ended():
                break
            for j in range(i + 1, len(self.aux)):
                if PsiOpts.is_timer_ended():
                    break
                cs = self.copy()
                cs.aux = Comp.empty()
                cs.substitute(self.aux[i], self.aux[i] + self.aux[j])
                cs.substitute(self.aux[j], self.aux[i] + self.aux[j])
                with PsiOpts(proof_enabled = False):
                    if selfplain.implies(cs):
                        with PsiOpts(meta_subs_criteria = "eqtype"):
                            self.substitute(self.aux[j], self.aux[i])
                        if r_did is not None:
                            r_did[0] = True
                        did = True
                        break
            if did:
                break
        if did:
            continue
        # Phase 2: fold auxiliary j into auxiliary i and drop j.
        for i in range(len(self.aux)):
            if PsiOpts.is_timer_ended():
                break
            for j in range(len(self.aux)):
                if i == j:
                    continue
                if PsiOpts.is_timer_ended():
                    break
                cs = self.copy()
                cs.aux = Comp.empty()
                cs.substitute(self.aux[j], Comp.empty())
                cs.substitute(self.aux[i], self.aux[i] + self.aux[j])
                with PsiOpts(proof_enabled = False):
                    if selfplain.implies(cs):
                        with PsiOpts(meta_subs_criteria = "eqtype"):
                            self.substitute(self.aux[j], Comp.empty())
                        if r_did is not None:
                            r_did[0] = True
                        did = True
                        break
            if did:
                break
    return self
def simplify_aux(self, reg = None, cases = None, r_did = None):
    """Try to express auxiliaries in terms of other variables, in place.

    For each auxiliary a, a fresh copy a2 is created, and the auxiliary
    search (check_getaux_inplace_gen) is run on
    (consonly with a renamed to a2) & reg  >>  self.exists(a).
      - A plain ("") signal gives a substitution for a. If it changes a,
        a is substituted and dropped from the auxiliaries.
      - A "leaveone" signal (handled only when cases is a list) gives a
        substitution valid in a case. A copy of self with that substitution
        is simplified recursively and appended to cases, and rr[2] >= 0 is
        added to self.
    Passes repeat while something changes. r_did[0] is set to True on any
    change. If there is an implication part, only consonly() is processed
    and shared back. Returns self.
    """
    if self.aux.isempty():
        return self
    if self.imp_present():
        cs = self.consonly()
        cs.simplify_aux(reg, cases, r_did)
        self.cons_shallow_copy_(cs)
        return self
    # Request "leaveone" signals only when the caller collects cases.
    leaveone = None
    if cases is not None:
        leaveone = True
    if reg is None:
        reg = Region.universe()
    did = True
    while did:
        if self.aux.isempty():
            break
        did = False
        index = IVarIndex()
        self.record_to(index)
        reg.record_to(index)
        # Temporarily clear aux/auxi; taux2 tracks the auxiliaries that remain.
        taux = self.aux.copy()
        taux2 = taux.copy()
        tauxi = self.auxi.copy()
        self.aux = Comp.empty()
        self.auxi = Comp.empty()
        for a in taux:
            if PsiOpts.is_timer_ended():
                break
            a2 = Comp.rv(index.name_avoid(a.get_name()))
            cs = (self.consonly().substituted(a, a2) & reg) >> self.exists(a)
            hint_aux_avoid = [(a, a2)]
            with PsiOpts(proof_enabled = False):
                for rr in cs.check_getaux_inplace_gen(hint_aux_avoid = hint_aux_avoid, leaveone = leaveone):
                    stype = iutil.signal_type(rr)
                    with PsiOpts(meta_subs_criteria = "eqtype"):
                        if stype == "":
                            ar = a.copy()
                            Comp.substitute_list(ar, rr)
                            if ar == a:
                                continue
                            self.substitute(a, ar)
                            taux2 -= a
                            if r_did is not None:
                                r_did[0] = True
                            did = True
                            break
                        elif stype == "leaveone" and cases is not None:
                            ar = a.copy()
                            Comp.substitute_list(ar, rr[1])
                            if ar == a:
                                continue
                            tr = self.copy()
                            tr.substitute(a, ar)
                            tr.aux = taux2 - a
                            tr.auxi = tauxi.copy()
                            tr.simplify_quick()
                            tr.simplify_aux(reg, cases)
                            cases.append(tr)
                            self.iand_norename(rr[2] >= 0)
                            if r_did is not None:
                                r_did[0] = True
                            did = True
        self.aux = taux2
        self.auxi = tauxi
        if did:
            self.simplify_quick()
    return self
def simplify_imp(self, reg = None):
    """Placeholder for simplifying the implication part of the region.

    Not implemented: it only calls imp_present() and returns None in every
    case. reg is currently unused.
    """
    if not self.imp_present():
        return
def simplify_expr_exhaust(self):
    """Run simplify_exhaust() on every constraint expression, in place.

    Covers exprs_ge, exprs_eq, exprs_gei and exprs_eqi. Returns self.
    """
    for group in (self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi):
        for expr in group:
            expr.simplify_exhaust()
    return self
def expanded_cases_reduce(self, reg = None, skip_simplify = False):
    """Split the region into cases via simplify_aux and return their union.

    A copy of self is processed by simplify_aux(reg, cases). If no extra
    cases arise, that copy is returned. Otherwise the union of the copy and
    the cases is returned, simplified unless skip_simplify is True.
    """
    base = self.copy()
    extra_cases = []
    base.simplify_aux(reg, extra_cases)
    if not extra_cases:
        return base
    united = RegionOp.union([base] + extra_cases)
    return united if skip_simplify else united.simplified()
def expanded_cases(self, reg = None, leaveone = True, skip_simplify = False):
    """Split the region into cases found by the auxiliary search.

    Returns None if there are no auxiliaries or the region has an
    implication part. Otherwise the auxiliaries are renamed (rmap), and
    check_getaux_inplace_gen is run on (consonly & reg) >> renamed region.
    Every "leaveone" signal rr produces a case: the renamed region with
    rr[1] substituted and the original auxiliaries eliminated. It also adds
    rr[2] >= 0 to the remaining region. That remaining region, renamed back,
    is the last case. A single case is returned as is; several are united
    and simplified unless skip_simplify is True.

    The per-case diagnostic line (rr[1], rr[2]) is printed only when the
    "verbose_expanded_cases" setting is enabled.
    """
    if self.aux.isempty():
        return None
    if self.imp_present():
        return None
    verbose = PsiOpts.settings.get("verbose_expanded_cases", False)
    cases = []
    if isinstance(reg, RegionOp):
        reg = reg.tosimple()
    if reg is None:
        reg = Region.universe()
    index = IVarIndex()
    self.record_to(index)
    reg.record_to(index)
    rmap = index.calc_rename_map(self.aux)
    cs2 = self.copy()
    cs2.rename_map(rmap)
    s2 = self.consonly()
    s2.aux = Comp.empty()
    cs = (s2 & reg) >> cs2
    hint_aux_avoid = cs.get_aux_avoid_list()
    with PsiOpts(forall_multiuse_numsave = 0):
        for rr in cs.check_getaux_inplace_gen(hint_aux_avoid = hint_aux_avoid, leaveone = leaveone):
            stype = iutil.signal_type(rr)
            if stype == "":
                # Plain (non-case) signals do not produce a case here.
                continue
            if stype == "leaveone":
                if verbose:
                    print(str(rr[1]) + " " + str(rr[2]))
                cs3 = cs2.copy()
                Comp.substitute_list(cs3, rr[1], isaux = True)
                cs3.eliminate(self.aux.inter(cs3.allcomp()))
                cases.append(cs3)
                cs2.iand_norename(rr[2] >= 0)
            if PsiOpts.is_timer_ended():
                break
    cs3 = cs2.copy()
    cs3.rename_map({b: a for a, b in rmap.items()})
    cases.append(cs3)
    if len(cases) == 1:
        return cases[0]
    r2 = RegionOp.union()
    for a in cases:
        r2 |= a
    if not skip_simplify:
        return r2.simplified()
    return r2
def simplify_aux_empty(self, reg = None, skip_simplify = False):
    """Remove groups of auxiliaries that can be set to empty, in place.

    cs is self with auxiliaries treated as ordinary variables, intersected
    with reg. For each subset taux of the auxiliaries yielded by
    igen.subset(self.aux, minsize = 2), if cs implies the region with taux
    emptied (implies_saved), taux is substituted by the empty Comp. The
    region is quickly simplified unless skip_simplify, and the method
    recurses. NOTE(review): minsize = 2 appears to skip single-auxiliary
    subsets -- confirm this is intended.
    If there is an implication part, only consonly() is processed and
    shared back. Returns self.
    """
    if self.aux.isempty():
        return self
    if self.imp_present():
        cs = self.consonly()
        cs.simplify_aux_empty(reg, skip_simplify)
        self.cons_shallow_copy_(cs)
        return self
    if isinstance(reg, RegionOp):
        reg = reg.tosimple()
    if reg is None:
        reg = Region.universe()
    cs = self.copy()
    cs.aux = Comp.empty()
    cs = cs & reg
    # index_self / progs are reused across implies_saved calls on the same premise cs.
    index_self = IVarIndex()
    cs.record_to(index_self)
    progs = []
    for taux in igen.subset(self.aux, minsize = 2):
        cs2 = self.copy()
        cs2.aux = Comp.empty()
        cs2.substitute(taux, Comp.empty())
        if cs.implies_saved(cs2, index_self, progs):
            with PsiOpts(meta_subs_criteria = "eqtype"):
                self.substitute_aux(taux, Comp.empty())
            if not skip_simplify:
                self.simplify_quick()
            self.simplify_aux_empty(reg, skip_simplify)
            return self
    return self
def simplify_aux_recombine(self, reg = None, skip_simplify = False, r_did = None):
    """Re-express the auxiliaries through the best substitution found, in place.

    The auxiliaries are renamed (rmap), and check_getaux_inplace_gen is run
    on (consonly & reg) >> renamed region. Each plain ("") signal gives new
    values for the renamed auxiliaries. Among them, the one minimizing
    (number of original auxiliaries still used, -total size) is chosen; the
    initial bound is (len(aux), -len(aux)). If one is found, the renamed
    auxiliaries are substituted by it and the original auxiliaries are
    eliminated. r_did[0] is set to True, and the region is quickly
    simplified unless skip_simplify. If there is an implication part, only
    consonly() is processed and shared back. Returns self.
    """
    if self.aux.isempty():
        return self
    if self.imp_present():
        cs = self.consonly()
        cs.simplify_aux_recombine(reg, skip_simplify, r_did)
        self.cons_shallow_copy_(cs)
        return self
    if isinstance(reg, RegionOp):
        reg = reg.tosimple()
    if reg is None:
        reg = Region.universe()
    did = False
    index = IVarIndex()
    self.record_to(index)
    reg.record_to(index)
    rmap = index.calc_rename_map(self.aux)
    cs2 = self.copy()
    cs2.rename_map(rmap)
    s2 = self.consonly()
    s2.aux = Comp.empty()
    cs = (s2 & reg) >> cs2
    hint_aux_avoid = cs.get_aux_avoid_list()
    # Best score so far: (auxiliaries still used, -total size); lower is better.
    cmin = (len(self.aux), -len(self.aux))
    minlist = None
    with PsiOpts(auxsearch_leaveone_add_ineq = False):
        for rr in cs.check_getaux_inplace_gen(hint_aux_avoid = hint_aux_avoid):
            stype = iutil.signal_type(rr)
            if stype == "":
                cnaux = 0
                csize = 0
                clist = []
                cvar = Comp.empty()
                for a in cs2.aux:
                    b = a.copy()
                    Comp.substitute_list(b, rr)
                    csize += len(b)
                    cvar += b
                    clist.append(b)
                for a in self.aux:
                    if cvar.ispresent(a):
                        cnaux += 1
                t = (cnaux, -csize)
                if t < cmin:
                    cmin = t
                    minlist = clist
            if PsiOpts.is_timer_ended():
                break
    if minlist is None:
        return self
    prevaux = self.aux.copy()
    self.rename_map(rmap)
    selfaux = self.aux.copy()
    with PsiOpts(meta_subs_criteria = "eqtype"):
        for a, b in zip(selfaux, minlist):
            self.substitute_aux(a, b)
    self.eliminate(prevaux.inter(self.allcomp()))
    if r_did is not None:
        r_did[0] = True
    if not skip_simplify:
        self.simplify_quick()
    return self
def sort(self):
    """Sort the region's constraints in place.

    First each expression sorts its own terms (sort()); then each of
    exprs_ge, exprs_eq, exprs_gei and exprs_eqi is sorted by
    sorting_tuple_eqn().
    """
    groups = (self.exprs_ge, self.exprs_eq, self.exprs_gei, self.exprs_eqi)
    for group in groups:
        for expr in group:
            expr.sort()
    for group in groups:
        group.sort(key = lambda expr: expr.sorting_tuple_eqn())
def simplify(self, reg = None, zero_group = 0, **kwargs):
    """Simplify a region in place.

    Optional argument reg with constraints assumed to be true.
    zero_group = 2: group all nonnegative terms as a single inequality.

    Keyword arguments are applied as temporary "simplify_<key>" settings
    for the duration of the call (e.g. aux_hull = True sets
    "simplify_aux_hull"). Nothing happens unless "simplify_enabled" is set.
    The passes are driven by PsiOpts settings, and the sequence is repeated
    "simplify_num_iter" times. Returns self.
    """
    if kwargs:
        r = None
        with PsiOpts(**{"simplify_" + key: val for key, val in kwargs.items()}):
            r = self.simplify(reg, zero_group)
        return r
    if not PsiOpts.settings.get("simplify_enabled", False):
        return self
    if reg is None:
        reg = Region.universe()
    nit = PsiOpts.settings.get("simplify_num_iter", 1)
    if PsiOpts.settings.get("simplify_regterm", False):
        self.simplify_regterm(reg)
    # simplify_regterm is switched off while the passes below run.
    with PsiOpts(simplify_regterm = False):
        for it in range(nit):
            self.simplify_quick(reg, zero_group)
            if not PsiOpts.settings.get("simplify_quick", False):
                if PsiOpts.settings.get("simplify_remove_missing_aux", False):
                    self.remove_missing_aux()
                if PsiOpts.settings.get("simplify_aux_commonpart", False):
                    self.simplify_aux_commonpart(reg,
                        maxlen = PsiOpts.settings.get("simplify_aux_xor_len", False))
                # NOTE(review): the next three aux passes are called without reg,
                # unlike the passes further below -- confirm this is intended.
                if PsiOpts.settings.get("simplify_aux_empty", False):
                    self.simplify_aux_empty()
                if PsiOpts.settings.get("simplify_aux_recombine", False):
                    self.simplify_aux_recombine()
                elif PsiOpts.settings.get("simplify_aux", False):
                    self.simplify_aux()
                # If combining auxiliaries changed anything, rerun recombine / aux.
                r_did = [False]
                if PsiOpts.settings.get("simplify_aux_combine", False):
                    self.simplify_aux_combine(reg, r_did)
                    if r_did[0]:
                        if PsiOpts.settings.get("simplify_aux_recombine", False):
                            self.simplify_aux_recombine()
                        elif PsiOpts.settings.get("simplify_aux", False):
                            self.simplify_aux()
                if PsiOpts.settings.get("simplify_redundant", False):
                    self.simplify_redundant(reg, full = PsiOpts.settings.get("simplify_redundant_full", False))
                if PsiOpts.settings.get("simplify_aux_hull", False):
                    self.simplify_aux_hull(reg)
                elif PsiOpts.settings.get("simplify_aux_hull_lower_complexity", False):
                    self.simplify_aux_hull(reg, quick = True, lower_complexity = True)
                if PsiOpts.settings.get("simplify_aux_eq", False):
                    self.simplify_aux_eq(reg)
                if PsiOpts.settings.get("simplify_bayesnet", False):
                    self.simplify_bayesnet(reg)
                if PsiOpts.settings.get("simplify_expr_exhaust", False):
                    self.simplify_expr_exhaust()
                if PsiOpts.settings.get("simplify_pair", False):
                    self.simplify_pair(reg)
                # Earlier passes may have removed auxiliaries' last occurrences.
                if PsiOpts.settings.get("simplify_remove_missing_aux", False):
                    self.remove_missing_aux()
                if PsiOpts.settings.get("simplify_aux_strengthen", False):
                    self.aux_strengthen()
            if PsiOpts.settings.get("simplify_sort", False):
                self.sort()
    return self
def simplify_union(self, reg = None):
    """Simplify in place (for a plain region this is simplify(reg)); return self."""
    region = self
    region.simplify(reg)
    return region
def simplified_quick(self, reg = None, zero_group = 0):
    """Return a quickly simplified copy of the region; self is unchanged.

    reg holds constraints assumed to be true (defaults to the universe).
    zero_group = 2: group all nonnegative terms as a single inequality.
    """
    assumed = reg if reg is not None else Region.universe()
    result = self.copy()
    result.simplify_quick(assumed, zero_group)
    return result
def simplified(self, reg = None, zero_group = 0, **kwargs):
    """Return a simplified copy of the region; self is unchanged.

    reg holds constraints assumed to be true (defaults to the universe).
    zero_group = 2: group all nonnegative terms as a single inequality.
    Extra keyword arguments are forwarded to simplify().
    """
    assumed = reg if reg is not None else Region.universe()
    result = self.copy()
    result.simplify(assumed, zero_group, **kwargs)
    return result
def remove_trivial(self):
    """Drop trivially satisfied constraints, in place.

    Removes ">= 0" constraints whose expression is nonnegative, and "== 0"
    constraints whose expression is both nonnegative and nonpositive. Both
    the main and the implication lists are filtered.
    """
    def keep_ge(expr):
        return not expr.isnonneg()

    def keep_eq(expr):
        return not (expr.isnonneg() and expr.isnonpos())

    self.exprs_gei = list(filter(keep_ge, self.exprs_gei))
    self.exprs_eqi = list(filter(keep_eq, self.exprs_eqi))
    self.exprs_ge = list(filter(keep_ge, self.exprs_ge))
    self.exprs_eq = list(filter(keep_eq, self.exprs_eq))
def removed_trivial(self):
    """Return a copy of the region with trivially satisfied constraints removed."""
    result = self.copy()
    result.remove_trivial()
    return result
def get_sum_seq(self, prefer_short = True, simplify_exhaust = False, target = None, bnet = None):
    """Greedily order the constraints as a sequence of partial sums.

    Starting from c = 0, each step picks an unused expression x from
    exprs_ge + exprs_eq. With prefer_short, it is the one minimizing
    2 * complexity(c + x) - complexity(x); otherwise it is the first unused
    one. x is appended to r0 and the new partial sum (c + x, quickly
    simplified with bnet) to r1.

    If target is given, the pairs are walked from the last one. Each partial
    sum is simplified towards target (simplify_target), and x is then
    subtracted from target. The walk stops once more than about half the
    sequence has been visited.

    Returns
    -------
    (r0, r1) : the chosen expressions and the corresponding partial sums.
    """
    exprs = list(self.exprs_ge + self.exprs_eq)
    c = Expr.zero()
    r0 = []
    r1 = []
    exprs_vis = [False] * len(exprs)
    eseq = []
    for it in range(len(exprs)):
        minc = None
        minx = None
        mincomp = 1e20
        mini = 0
        for i, x in enumerate(exprs):
            if exprs_vis[i]:
                continue
            t = None
            with PsiOpts(proof_enabled = False):
                t = (c + x).simplified_quick(bnet = bnet)
                if simplify_exhaust:
                    t.simplify_break_hc()
            tcomp = t.complexity() * 2 - x.complexity()
            if tcomp < mincomp:
                mincomp = tcomp
                minc = t
                minx = x
                mini = i
            if not prefer_short:
                break
        r0.append(minx)
        r1.append(minc)
        c = minc
        exprs_vis[mini] = True
        eseq.append(mini)
    if target is not None:
        for i, (x, t) in enumerate(reversed(list(zip(r0, r1)))):
            t.simplify_target(target, bnet = bnet)
            if i * 2 > len(r0):
                break
            target = (target - x).simplified_quick(bnet = bnet)
            if simplify_exhaust:
                target.simplify_break_hc()
    return (r0, r1)
def get_ic(self, include_ic = True, include_hc = False, skip_simplify = False):
    """Return the sum of the selected terms of all sign-definite constraints.

    Unless skip_simplify, the region is first quickly simplified with
    zero_group = 2. The sign-definite constraints are ">= 0" with a
    nonpositive expression, and "== 0" with a nonpositive or nonnegative
    expression. From them, every term with isic2() (if include_ic) or with
    ishc() (if include_hc) is added with coefficient 1.
    """
    source = self if skip_simplify else self.simplified_quick(zero_group = 2)
    zero_forced = [x for x in source.exprs_ge if x.isnonpos()]
    zero_forced += [x for x in source.exprs_eq if x.isnonpos() or x.isnonneg()]
    total = Expr.zero()
    for expr in zero_forced:
        for term, _ in expr.terms:
            if (include_ic and term.isic2()) or (include_hc and term.ishc()):
                total += Expr.fromterm(term)
    return total
def remove_ic(self):
    """Strip isic2() terms from sign-definite constraints and return their sum.

    The affected constraints are ">= 0" with a nonpositive expression and
    "== 0" with a nonpositive or nonnegative expression. Their term lists
    are edited in place (the cached mhash is reset). Constraints left
    identically zero are dropped. The removed terms are returned summed,
    each with coefficient 1.
    """
    removed = Expr.zero()
    targets = [x for x in self.exprs_ge if x.isnonpos()]
    targets += [x for x in self.exprs_eq if x.isnonpos() or x.isnonneg()]
    for expr in targets:
        kept = []
        for term, coeff in expr.terms:
            if term.isic2():
                removed += Expr.fromterm(term)
            else:
                kept.append((term, coeff))
        expr.terms = kept
        # The term list was edited directly, so invalidate the cached hash.
        expr.mhash = None
    self.exprs_ge = [x for x in self.exprs_ge if not x.iszero()]
    self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()]
    return removed
def allcomprv_sorted(self):
    """Return allcomprv(), reordered by the "rv_order" entry of self.meta if present.

    Variables listed in the order come first (in that order), followed by
    the rest.
    """
    rvs = self.allcomprv()
    order = None if self.meta is None else self.meta.get("rv_order", None)
    if order is None:
        return rvs
    order = iutil.ensure_comp(order)
    if order is None:
        return rvs
    return order.inter(rvs) + rvs
def get_bayesnet(self, vs = None, roots = None, semigraphoid_iter = None, get_list = False, skip_simplify = False):
    """Return a Bayesian network containing the conditional independence
    conditions in this region.

    With vs given, the independences come from completed_semigraphoid_ic
    restricted to vs. Otherwise they come from get_ic(), plus
    completed_semigraphoid_ic when semigraphoid_iter > 0 (the default is
    the "bayesnet_semigraphoid_iter" setting). With get_list, a list is
    returned via BayesNet.from_ic_list; otherwise a topologically sorted
    (tsorted) network is returned.
    """
    if semigraphoid_iter is None:
        semigraphoid_iter = PsiOpts.settings["bayesnet_semigraphoid_iter"]
    if vs is None:
        icexpr = self.get_ic(skip_simplify = skip_simplify)
        if semigraphoid_iter > 0:
            icexpr += self.completed_semigraphoid_ic(max_iter = semigraphoid_iter)
    else:
        icexpr = self.completed_semigraphoid_ic(vs = vs, max_iter = semigraphoid_iter)
    add_var = self.allcomprv_sorted().inter(icexpr.allcomprv())
    if get_list:
        return BayesNet.from_ic_list(icexpr, roots = roots, add_var = add_var)
    return BayesNet.from_ic(icexpr, roots = roots, add_var = add_var).tsorted()
def graph(self, vs = None, **kwargs):
    """Return the Bayesian network among the random variables as a
    graphviz digraph that can be displayed in the console.

    vs is forwarded to get_bayesnet(); kwargs are forwarded to the
    network's graph() method.
    """
    bnet = self.get_bayesnet(vs = vs)
    return bnet.graph(**kwargs)
def get_bayesnet_imp(self, skip_simplify = False, add_hc = True):
    """Return a Bayesian network containing the conditional independence
    conditions in the exprs_gei / exprs_eqi part of this region.

    The region is first simplified with simplified_quick(zero_group = 2)
    unless skip_simplify is True. Non-positive ">=" expressions are added.
    Sign-definite "==" expressions are added (non-positive) or subtracted
    (non-negative). The result is passed to BayesNet.from_ic and
    topologically sorted.
    """
    cs = self if skip_simplify else self.simplified_quick(zero_group = 2)
    icexpr = Expr.zero()
    for expr in cs.exprs_gei:
        if expr.isnonpos():
            icexpr += expr
    for expr in cs.exprs_eqi:
        if expr.isnonpos():
            icexpr += expr
            continue
        if expr.isnonneg():
            icexpr -= expr
    bnet = BayesNet.from_ic(icexpr, add_var = self.allcomprv(), add_hc = add_hc)
    return bnet.tsorted()
def get_markov(self):
    """Get Markov chains as a list of lists.

    The chains of every network returned by get_bayesnet(get_list = True)
    are concatenated.
    """
    chains = []
    for bnet in self.get_bayesnet(get_list = True):
        chains.extend(bnet.get_markov())
    return chains
def table(self, *args, skip_cons = False, plot = True, use_latex = None, **kwargs):
    """Plot the information diagram as a Karnaugh map.

    Positional arguments may be:
      - Comp: the random variables to display. If none are given, all RVs
        recorded from this region and the other arguments are used.
      - ConcModel / LinearProg / SharedRVModel: a model whose values
        (model[...]) are attached to the cells and expressions.
      - Expr / Region: extra expressions to draw on the diagram.

    skip_cons : If True, do not derive or draw this region's own constraints.
    plot      : If True, call CellTable.plot and return None. Otherwise
                return the CellTable.
    use_latex : Forwarded to CellTable.plot. Defaults to
                PsiOpts.settings["venn_latex"].
    Other keyword arguments are forwarded to CellTable.plot.
    """
    if use_latex is None:
        use_latex = PsiOpts.settings["venn_latex"]
    imp_r = self.imp_flippedonly_noaux()
    if not imp_r.isuniverse():
        # skip_cons / plot / use_latex are keyword-only. They must be passed
        # by name, otherwise they end up in *args and are reset to defaults.
        return imp_r.table(self.consonly(), *args, skip_cons = skip_cons,
                           plot = plot, use_latex = use_latex, **kwargs)
    ceps = PsiOpts.settings["eps"]
    index = IVarIndex()

    # The last model-like argument wins.
    cmodel = None
    for a in args:
        if isinstance(a, (ConcModel, LinearProg, SharedRVModel)):
            cmodel = a

    comprv = Comp.empty()
    for a in args:
        if isinstance(a, Comp):
            a.record_to(index)
            comprv += a
    self.get_bayesnet().allcomp().record_to(index)
    if cmodel is not None:
        if isinstance(cmodel, ConcModel):
            cmodel.get_bayesnet().allcomp().record_to(index)
        if isinstance(cmodel, SharedRVModel):
            cmodel.index_rv.comprv.record_to(index)
    for a in args:
        if isinstance(a, (Expr, Region)):
            a.record_to(index)

    cs = self
    if cmodel is not None:
        if isinstance(cmodel, ConcModel):
            cs = cs & cmodel.get_bayesnet().get_region()
    cs.record_to(index)
    if comprv.isempty():
        comprv = index.comprv

    r = CellTable(comprv)
    creg = Region.universe()
    for mask in range(1, 1 << len(comprv)):
        # Cell of the diagram: I(variables selected by mask | all others).
        ch = I(alland(comprv[i] for i in range(len(comprv)) if mask & (1 << i))
               | comprv.from_mask(((1 << len(comprv)) - 1) ^ mask))
        ispos = ch.isnonneg() or cs.implies(ch >= 0)
        isneg = cs.implies(ch <= 0)
        if ispos and isneg:
            r.set_attr(mask, "enabled", False)
            if not skip_cons:
                creg.iand_norename(ch == 0)
        elif ispos:
            if cmodel is None:
                r.set_attr(mask, "ispos", True)
            if not skip_cons and not ch.isnonneg():
                creg.iand_norename(ch >= 0)
        elif isneg:
            if cmodel is None:
                r.set_attr(mask, "isneg", True)
            if not skip_cons:
                creg.iand_norename(ch <= 0)
        if cmodel is not None:
            r.set_attr(mask, "cval", cmodel[ch])

    cargs = []
    if not skip_cons:
        # Draw only the constraints not already implied by the per-cell
        # sign information accumulated in creg.
        for cexpr in self.exprs_ge:
            if not creg.implies(cexpr >= 0):
                cargs.append(cexpr >= 0)
        for cexpr in self.exprs_eq:
            if not creg.implies(cexpr == 0):
                cargs.append(cexpr == 0)
    cargs += args

    for a in cargs:
        exprlist = []
        if isinstance(a, Expr):
            exprlist = [(a, None)]
        elif isinstance(a, Region):
            exprlist = [(b, ">=") for b in a.exprs_ge] + [(b, "==") for b in a.exprs_eq]
        for b, sn in exprlist:
            if not comprv.super_of(b.allcomprv()):
                continue
            if sn == ">=":
                r.add_expr(b >= 0, None if cmodel is None else cmodel[b >= 0])
            elif sn == "==":
                r.add_expr(b == 0, None if cmodel is None else cmodel[b == 0])
            else:
                r.add_expr(b, None if cmodel is None else cmodel[b])
            # Spread the coefficient of each IC term over every cell it
            # covers: the cell must meet all of v.x and avoid v.z.
            maskvals = [0.0 for mask in range(1 << len(comprv))]
            for v, c in b.terms:
                if v.get_type() == TermType.IC:
                    xmasks = [comprv.get_mask(x) for x in v.x]
                    zmask = comprv.get_mask(v.z)
                    for mask in range(1, 1 << len(comprv)):
                        if mask & zmask:
                            continue
                        if any(not(mask & xmask) for xmask in xmasks):
                            continue
                        maskvals[mask] += c
            for mask in range(1, 1 << len(comprv)):
                if abs(maskvals[mask]) > ceps:
                    r.set_expr_val(mask, maskvals[mask])
    if plot:
        r.plot(use_latex = use_latex, **kwargs)
    else:
        return r
def venn(self, *args, style = None, **kwargs):
    """Plot the information diagram as a Venn diagram.
    Can handle up to 5 random variables (uses Branko Grunbaum's Venn diagram for n=5).

    This is table() with "venn,blend," prepended to the style string.
    """
    extra = "" if style is None else style
    return self.table(*args, style = "venn,blend," + extra, **kwargs)
def sphere(self, entries = None, lib = None, ndiv = 480, plot = True, use_latex = None, **kwargs):
    """Plot the cone as a sphere.

    Points of the unit sphere are taken as values of three real
    coordinates. Each point is colored by whether the region's constraints
    hold there (cs.evalcheck_vec).

    entries : Up to 3 coordinates. Each is an Expr, or a tuple
              (label, Expr). If None, the first 3 of list(self.reals) are
              used. Items that are neither tuple nor Expr are skipped.
              Exactly 3 must remain, otherwise ValueError is raised.
    lib     : Plotting backend, resolved by PsiOpts.get_plotlib.
              "matplotlib" and "plotly" are handled; any other value
              returns None.
    ndiv    : Number of azimuthal divisions (ndiv // 2 polar divisions).
    plot    : If True, show the figure. Otherwise return the matplotlib
              axes or the plotly figure.
    NOTE(review): use_latex is resolved but not used below, and **kwargs
    is ignored.
    """
    lib = PsiOpts.get_plotlib(lib)
    if entries is None:
        entries = list(self.reals)
        if len(entries) > 3:
            entries = entries[:3]
    # Normalize to (label, expression) pairs.
    entries2 = []
    for t in entries:
        if isinstance(t, tuple):
            entries2.append(t)
        elif isinstance(t, Expr):
            entries2.append((t, t))
    entries = entries2
    if len(entries) > 3:
        raise ValueError("Cannot plot more than 3 dimensions.")
    if len(entries) < 3:
        raise ValueError("Requires at least 3 dimensions.")
    if use_latex is None:
        use_latex = PsiOpts.settings["venn_latex"]
    cs = self
    csindex = IVarIndex()
    cs.record_to(csindex)
    entriesindex = IVarIndex()
    for a, b in entries:
        b.record_to(entriesindex)
    # If the region involves random variables, or the coordinates do not
    # cover all of its real variables, introduce temporary reals #TMP
    # (via discover) and use those as the plotted coordinates.
    if len(csindex.comprv) or len(entriesindex.comprv) or not entriesindex.compreal.super_of(csindex.compreal):
        tmpreals = real_array("#TMP", len(entries))
        cs = cs.discover([(t, b) for t, (a, b) in zip(tmpreals, entries)])
        entries = [(a, t) for t, (a, b) in zip(tmpreals, entries)]
    # Spherical grid: u is the polar angle, v the azimuth.
    u, v = numpy.meshgrid(numpy.linspace(0, numpy.pi, ndiv // 2 + 1), numpy.linspace(0, 2*numpy.pi, ndiv + 1))
    x = numpy.sin(u) * numpy.cos(v)
    y = numpy.sin(u) * numpy.sin(v)
    z = numpy.cos(u)
    model = SharedRVModel()
    # for t in entries:
    #     model[t] = 0.0
    model[entries[0][1]] = x
    model[entries[1][1]] = y
    model[entries[2][1]] = z
    # Vectorized membership test over the whole grid; indexed as isin[i, j]
    # below.
    isin = cs.evalcheck_vec(model)
    col_back = [0.85, 0.85, 0.85]
    col_front = [0, 0, 1]
    col_line = [1, 0, 0]
    col_line2 = [1, 0.6, 0.3]
    # def col(x, y, z):
    #     model[entries[0]] = x
    #     model[entries[1]] = y
    #     model[entries[2]] = z
    #     if cs.evalcheck(model):
    #         return col_front
    #     else:
    #         return col_back
    # def isin(x, y, z):
    #     model[entries[0]] = x
    #     model[entries[1]] = y
    #     model[entries[2]] = z
    #     return int(cs.evalcheck(model))
    if lib == "matplotlib":
        # Presumably imported only to register the "3d" projection on older
        # matplotlib versions; the name itself is unused.
        from mpl_toolkits.mplot3d import Axes3D
        # Per-vertex RGB colors: col_front inside the region, col_back outside.
        cols = numpy.zeros(x.shape + (3,))
        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                if isin[i, j]:
                    cols[i, j, :] = col_front
                else:
                    cols[i, j, :] = col_back
        fig = plt.figure()
        ax = fig.add_subplot(111, projection = "3d")
        ax.plot_surface(x, y, z, linewidth = 0, antialiased = False,
            facecolors = cols, shade = False, rstride = 1, cstride = 1)
        # ax.plot_wireframe(x, y, z, linewidth = 1, antialiased = True)
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)
        ax.set_zlim(-1, 1)
        ax.set_box_aspect([1,1,1])
        ax.set_xlabel(str(entries[0][0]))
        ax.set_ylabel(str(entries[1][0]))
        ax.set_zlabel(str(entries[2][0]))
        ax.set_proj_type("ortho")
        if plot:
            plt.show()
        else:
            return ax
    if lib == "plotly":
        def col2str(c):
            # [r, g, b] in [0, 1] -> "rgb(R,G,B)" with 0-255 components.
            return "rgb(" + ",".join(str(round(x * 255)) for x in c) + ")"
        # Membership values drive a two-color scale (0 -> back, 1 -> front).
        cols = isin
        colscale = [[0, col2str(col_back)], [1, col2str(col_front)]]
        def greatcircle(d, col):
            # Great circle orthogonal to direction d, drawn slightly outside
            # the unit sphere so it stays visible over the surface.
            rad = 1 + 1e-5
            ndiv2 = 600
            angs = numpy.linspace(0, 2*numpy.pi, ndiv2 + 1)
            d = numpy.array(d)
            d = d / numpy.linalg.norm(d)
            # Choose the coordinate axis giving the largest cross product
            # with d, to get a numerically stable first basis vector a.
            a = None
            for i in range(3):
                a1 = numpy.cross(d, [int(j==i) for j in range(3)])
                if a is None or numpy.linalg.norm(a1) > numpy.linalg.norm(a):
                    a = a1
            a = a / numpy.linalg.norm(a)
            b = numpy.cross(d, a)
            b = b / numpy.linalg.norm(b)
            return plotly.graph_objects.Scatter3d(
                x = (numpy.cos(angs)*a[0] + numpy.sin(angs)*b[0]) * rad,
                y = (numpy.cos(angs)*a[1] + numpy.sin(angs)*b[1]) * rad,
                z = (numpy.cos(angs)*a[2] + numpy.sin(angs)*b[2]) * rad,
                mode = "lines",
                line = dict(color = col2str(col))
            )
        # Sphere surface plus the great circles orthogonal to the 3 axes
        # (col_line) and to the 6 face diagonals (col_line2).
        fig = plotly.graph_objects.Figure(data=[
            plotly.graph_objects.Surface(x=x, y=y, z=z, surfacecolor=cols, colorscale=colscale, showscale=False,
                lighting=plotly.graph_objects.surface.Lighting(ambient=1, diffuse=0),
                contours = dict(x = dict(highlight = False), y = dict(highlight = False), z = dict(highlight = False))),
            greatcircle([1, 0, 0], col_line),
            greatcircle([0, 1, 0], col_line),
            greatcircle([0, 0, 1], col_line),
            greatcircle([1, 1, 0], col_line2),
            greatcircle([1, -1, 0], col_line2),
            greatcircle([0, 1, 1], col_line2),
            greatcircle([0, 1, -1], col_line2),
            greatcircle([1, 0, 1], col_line2),
            greatcircle([-1, 0, 1], col_line2)
        ])
        fig.layout.scene.camera.projection.type = "orthographic"
        fig.update_layout(scene = dict(
            # xaxis_title = str(entries[0][0]),
            # yaxis_title = str(entries[1][0]),
            # zaxis_title = str(entries[2][0])
            xaxis = dict(title = str(entries[0][0]), showspikes = False),
            yaxis = dict(title = str(entries[1][0]), showspikes = False),
            zaxis = dict(title = str(entries[2][0]), showspikes = False)
        ), showlegend = False)
        # fig.update_xaxes(showspikes = False)
        # fig.update_yaxes(showspikes = False)
        # fig.update_zaxes(showspikes = False)
        if plot:
            fig.show()
        else:
            return fig
def eliminate_term_eq(self, w):
    """Eliminate the term w using an equality constraint, in place.

    exprs_eqi is searched first, then exprs_eq, for the first expression
    whose coefficient on w exceeds PsiOpts.settings["eps"] in magnitude.
    That expression is solved for w, the constraint is zeroed, the solution
    is substituted for w everywhere, and zeroed equalities are dropped.

    Returns True if such an equality was found, otherwise False.
    """
    solution = None
    for eq_list in (self.exprs_eqi, self.exprs_eq):
        for x in eq_list:
            c = x.get_coeff(w)
            if abs(c) <= PsiOpts.settings["eps"]:
                continue
            solution = (x * (-1.0 / c)).removed_term(w)
            x.setzero()
            break
        if solution is not None:
            break
    if solution is None:
        return False
    self.substitute(Expr.fromterm(w), solution)
    self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()]
    self.exprs_eqi = [x for x in self.exprs_eqi if not x.iszero()]
    return True
def get_lb_ub_eq(self, w):
    """
    Get lower bounds, upper bounds and equality constraints about w.
    Parameters
    ----------
    w : Expr
        The real variable of interest.
    Returns
    -------
    tuple
        A tuple of 3 lists of Expr: Lower bounds, upper bounds and
        equality constraints about w.
    """
    if isinstance(w, Expr):
        w = Term.fromcomp(w.allcomp())
    eps = PsiOpts.settings["eps"]
    lower = []
    upper = []
    equal = []
    for x in self.exprs_ge:
        c = x.get_coeff(w)
        if abs(c) <= eps:
            continue
        # x = c*w + rest >= 0  =>  w >= -rest/c if c > 0, w <= -rest/c if c < 0
        bound = (x * (-1.0 / c)).removed_term(w)
        if c > 0:
            lower.append(bound)
        else:
            upper.append(bound)
    for x in self.exprs_eq:
        c = x.get_coeff(w)
        if abs(c) > eps:
            equal.append((x * (-1.0 / c)).removed_term(w))
    return (lower, upper, equal)
def get_one_bound(self, w, max_term = 1):
    """Return a single expression bounding or determining w, or None.

    If the region has equality constraints on w, those expressions are the
    candidates. Otherwise the candidates are:
      - the lower bound, or emax(...) of up to max_term lower bounds (when
        there are no more lower bounds than upper bounds);
      - the upper bound, or emin(...) of up to max_term upper bounds (when
        there are no more upper bounds than lower bounds).
    A side with no bounds contributes nothing, so None is returned when w
    is unconstrained. Among several candidates, the one with the smallest
    (complexity(), len(repr(...))) is returned (the first on ties), as a copy.
    """
    lower, upper, ee = self.get_lb_ub_eq(w)
    if not ee:
        # Require a non-empty side: emax()/emin() of nothing is meaningless.
        if len(lower) == 1:
            ee.append(lower[0])
        elif 0 < len(lower) <= max_term and len(lower) <= len(upper):
            ee.append(emax(*lower))
        if len(upper) == 1:
            ee.append(upper[0])
        elif 0 < len(upper) <= max_term and len(upper) <= len(lower):
            ee.append(emin(*upper))
    if not ee:
        return None
    if len(ee) == 1:
        return ee[0].copy()
    def tmp_complexity(x):
        return (x.complexity(), len(repr(x)))
    # min() returns the first minimal element, matching a strict-< scan.
    return min(ee, key = tmp_complexity).copy()
def eliminate_term(self, w, forall = False, reg_record = None):
    """Fourier-Motzkin elimination of the single term w, in place.

    Each constraint is classified by its coefficient c on w. A ">= 0"
    constraint gives a lower bound on w (c > 0) or an upper bound (c < 0).
    An "== 0" constraint gives an expression equal to w. If |c| <= eps,
    only w is removed from it. New constraints are then formed from
    upper - lower >= 0, or by comparing bounds and the other equalities
    with the first equality.

    Constraints from exprs_gei / exprs_eqi are collected first. If w occurs
    in any of them and forall is False, the region is set to universe and
    returned. Derived constraints stay in the ...i lists only when every
    constraint they come from is one of those.

    Parameters:
        w          : The Term to eliminate.
        forall     : See above.
        reg_record : Optional list. If given, (Expr.fromterm(w), region) is
                     appended, where region collects copies of the
                     exprs_ge / exprs_eq constraints that involved w, with
                     their original relation (>= or ==).

    Returns self.
    """
    verbose = PsiOpts.settings.get("verbose_eliminate", False)
    reg_record_r = None
    if reg_record is not None:
        reg_record_r = Region.universe()
        reg_record.append((Expr.fromterm(w), reg_record_r))
    el = []  # upper bounds on w
    er = []  # lower bounds on w
    ee = []  # expressions equal to w
    if verbose:
        print("=========== Eliminate ===========")
        print(w)
    for x in self.exprs_gei:
        c = x.get_coeff(w)
        if abs(c) <= PsiOpts.settings["eps"]:
            x.remove_term(w)
        elif c > 0:
            er.append((x * (-1.0 / c)).removed_term(w))
            x.setzero()
        else:
            el.append((x * (-1.0 / c)).removed_term(w))
            x.setzero()
    for x in self.exprs_eqi:
        c = x.get_coeff(w)
        if abs(c) <= PsiOpts.settings["eps"]:
            x.remove_term(w)
        else:
            ee.append((x * (-1.0 / c)).removed_term(w))
            x.setzero()
    # Counts of bounds coming from the ...i lists; these occupy the first
    # positions of el / er / ee.
    elni = len(el)
    erni = len(er)
    eeni = len(ee)
    if not forall and elni + erni + eeni > 0:
        self.setuniverse()
        return self
    for x in self.exprs_ge:
        c = x.get_coeff(w)
        if abs(c) <= PsiOpts.settings["eps"]:
            x.remove_term(w)
        else:
            if c > 0:
                er.append((x * (-1.0 / c)).removed_term(w))
            else:
                el.append((x * (-1.0 / c)).removed_term(w))
            if reg_record_r is not None:
                reg_record_r.exprs_ge.append(x.copy())
            x.setzero()
    for x in self.exprs_eq:
        c = x.get_coeff(w)
        if abs(c) <= PsiOpts.settings["eps"]:
            x.remove_term(w)
        else:
            ee.append((x * (-1.0 / c)).removed_term(w))
            if reg_record_r is not None:
                # Equalities must be recorded as equalities; recording them
                # as ">= 0" would keep only half of the constraint.
                reg_record_r.exprs_eq.append(x.copy())
            x.setzero()
    if len(ee) > 0:
        if eeni == 0:
            for i in range(elni):
                x = el[i]
                for j in range(erni):
                    y = er[j]
                    self.exprs_gei.append(x - y)
        for i in range(len(el)):
            x = el[i]
            if i < elni and 0 < eeni:
                self.exprs_gei.append(x - ee[0])
            else:
                self.exprs_ge.append(x - ee[0])
        for j in range(len(er)):
            y = er[j]
            if j < erni and 0 < eeni:
                self.exprs_gei.append(ee[0] - y)
            else:
                self.exprs_ge.append(ee[0] - y)
        for i in range(1, len(ee)):
            x = ee[i]
            if i < eeni:
                self.exprs_eqi.append(x - ee[0])
            else:
                self.exprs_eq.append(x - ee[0])
    else:
        for i in range(len(el)):
            x = el[i]
            for j in range(len(er)):
                y = er[j]
                if i < elni and j < erni:
                    self.exprs_gei.append(x - y)
                else:
                    self.exprs_ge.append(x - y)
    if verbose:
        print("=========== To ===========")
        print(self)
    return self
@staticmethod
def eliminate_reg_record_clean(reg_record, keep_terms, cross=False):
    """Propagate recorded constraints between entries of reg_record.

    For each recorded (term, region) whose term is not present in
    keep_terms (or for every entry when cross is True), every other
    recorded region that mentions the term, and whose own term does not
    contain it, is intersected with the region and then has the term
    eliminated.
    """
    for term, reg in reg_record:
        if not cross and keep_terms.ispresent(term):
            continue
        for other_term, other_reg in reg_record:
            if term.ispresent(other_term) or not other_reg.ispresent(term):
                continue
            other_reg.iand_norename(reg)
            other_reg.eliminate(term)
def eliminate_toreal(self, w, forall = False):
    """Eliminate the random variable w by converting it to real variables.

    Both the constraint part and the flipped (implication) part are
    intersected with get_prog_region(toreal = w, toreal_only = True),
    w is removed, and every real variable introduced by this is then
    eliminated with eliminate_term (followed by simplify_quick).
    """
    verbose = PsiOpts.settings.get("verbose_eliminate_toreal", False)
    if verbose:
        print("========== elim real ========")
        print(self)
        print("========== var =========")
        print(w)
    reals_before = self.allcompreal()
    # Apply to the current side, flip, apply again, flip back.
    for _ in range(2):
        self.iand_norename(self.get_prog_region(toreal = w, toreal_only = True))
        self.imp_flip()
    self.remove_present(w)
    new_reals = self.allcompreal() - reals_before
    if verbose:
        print("========== to =========")
        print(self)
    for var in new_reals.varlist:
        self.eliminate_term(Term.fromcomp(Comp([var])), forall = forall)
        self.simplify_quick()
        if verbose:
            print("========== elim " + str(var) + " =========")
            print(self)
    return self
def eliminate_toreal_rays(self, w):
    """Replace the constraints with those from get_region_elim_rays(w).

    The linear program is built (type H) from the flipped constraint-only
    part of this region. Its exprs_ge / exprs_eq overwrite this region's.
    Returns self.
    """
    flipped = self.consonly().imp_flipped()
    index = IVarIndex()
    flipped.record_to(index)
    prog = flipped.init_prog(index, lptype = LinearProgType.H)
    reduced = prog.get_region_elim_rays(w)
    self.exprs_ge = reduced.exprs_ge
    self.exprs_eq = reduced.exprs_eq
    return self
def remove_aux(self, w):
    """Remove w from both auxiliary lists (aux and auxi), in place."""
    aux = self.aux
    aux -= w
    self.aux = aux
    auxi = self.auxi
    auxi -= w
    self.auxi = auxi
def remove_missing_aux(self):
    """Drop auxiliary random variables (aux and auxi) that no longer appear
    in allcomprv() of the region.
    """
    saved_aux, saved_auxi = self.aux, self.auxi
    # Clear both first so allcomprv() is computed without them.
    self.aux = Comp.empty()
    self.auxi = Comp.empty()
    present = self.allcomprv()
    self.aux = saved_aux.inter(present)
    self.auxi = saved_auxi.inter(present)
@staticmethod
def get_allcomp(w):
    """Normalize w into the variables it refers to.

    None -> Comp.empty(). CompArray / ExprArray / Expr -> their allcomp().
    tuple / list -> the sum of get_allcomp over the items. Anything else
    is returned unchanged.
    """
    if w is None:
        return Comp.empty()
    if isinstance(w, (CompArray, ExprArray)):
        w = w.allcomp()
    if isinstance(w, (tuple, list)):
        total = Comp.empty()
        for item in w:
            total = total + Region.get_allcomp(item)
        w = total
    return w.allcomp() if isinstance(w, Expr) else w
def eliminate_ic(self, w):
    """Eliminate random variables w using conditional independence, in place.

    The region is intersected with completed_semigraphoid of the remaining
    variables, split, and then w is removed via remove_relax. Returns self.
    """
    remaining = self.allcomprv() - w
    self.iand_norename(self.completed_semigraphoid(remaining))
    self.split()
    self.remove_relax(w)
    return self
def eliminate(self, w, reg = None, toreal = False, forall = False, quick = False, method = "", reg_record = None):
    """Fourier-Motzkin elimination, in place.
    w is the Expr object with the real variables to eliminate.
    If w contains random variables, they will be treated as auxiliary RV.

    reg        : Passed to simplify / simplify_quick.
    toreal     : Also eliminate random variables by converting them to reals
                 (eliminate_toreal). method == "real" sets this.
    forall     : Universal instead of existential quantification. Random
                 variables go to auxi instead of aux.
    quick      : Use simplify_quick instead of simplify.
    method     : "real" (see toreal), or "ic" / "ci" to also eliminate the
                 random variables through conditional independence
                 (eliminate_ic).
    reg_record : Forwarded to eliminate_term for real variables.

    Returns self, except in the cases handled by RegionOp (see below), where
    a new object is returned and self is left unchanged. Callers should use
    the return value.
    """
    w = Region.get_allcomp(w)
    # if isinstance(w, Comp):
    #     w = Expr.H(w)
    if method == "real":
        toreal = True
    # Cases this normal-form region cannot handle in place are delegated to
    # RegionOp, which returns a new object.
    # NOTE(review): quick, method and reg_record are not forwarded here.
    if ((not toreal and not forall and not self.auxi.isempty() and any(v.get_type() == IVarType.RV for v in w.allcomp()))
        or (forall and any(v.get_type() == IVarType.REAL for v in w.allcomp()))
        or (not forall and self.imp_present())):
        return RegionOp.inter([self]).eliminate(w, reg, toreal, forall)
    #self.simplify_quick(reg)
    if toreal and PsiOpts.settings["eliminate_rays"]:
        w2 = w
        w = Comp.empty()
        toelim = Comp.empty()
        for v in w2.allcomp():
            # NOTE(review): toreal is always True inside this branch, so
            # every v goes to w and toelim stays empty. As written,
            # eliminate_toreal_rays below is never called. Confirm whether
            # this is intentional.
            if toreal or v.get_type() == IVarType.REAL:
                w += v
            elif v.get_type() == IVarType.RV:
                toelim += v
        if not toelim.isempty():
            self.eliminate_toreal_rays(toelim)
    toelim = Comp.empty()
    rvs = Comp.empty()
    simplify_needed = False
    for v in w.allcomp():
        if v.get_type() == IVarType.REAL or toreal:
            # Simplify between consecutive eliminations. On the non-quick
            # path, the more expensive simplification passes are disabled.
            if simplify_needed:
                if quick:
                    self.simplify_quick(reg)
                else:
                    with PsiOpts(simplify_redundant_full = False,
                        simplify_aux_commonpart = False,
                        simplify_aux = False,
                        simplify_bayesnet = False):
                        self.simplify(reg)
            if v.get_type() == IVarType.REAL:
                self.eliminate_term(Term.fromcomp(v), forall = forall, reg_record = reg_record)
            else:
                self.eliminate_toreal(v, forall = forall)
            simplify_needed = True
        else:
            # Random variables are not removed; they are marked auxiliary.
            rvs += v
            if forall:
                self.auxi += v
            else:
                self.aux += v
    if method == "ic" or method == "ci":
        if not forall and not rvs.isempty():
            self.eliminate_ic(rvs)
            simplify_needed = True
    if simplify_needed:
        if quick:
            self.simplify_quick(reg)
        else:
            self.simplify(reg)
    return self
def eliminate_quick(self, w, reg = None, toreal = False, forall = False, method = ""):
    """Fourier-Motzkin elimination, in place.
    w is the Expr object with the real variables to eliminate.
    If w contains random variables, they will be treated as auxiliary RV.

    Same as eliminate(..., quick = True): only quick simplification is used.
    """
    options = dict(reg = reg, toreal = toreal, forall = forall, method = method)
    return self.eliminate(w, quick = True, **options)
def eliminated(self, w, reg = None, toreal = False, forall = False, method = ""):
    """Fourier-Motzkin elimination, return region after elimination.
    w is the Expr object with the real variable to eliminate.
    If w contains random variables, they will be treated as auxiliary RV.

    self is not modified; the elimination runs on a copy.
    """
    return self.copy().eliminate(w, reg, toreal, forall, method = method)
def eliminated_quick(self, w, reg = None, toreal = False, forall = False, method = ""):
    """Fourier-Motzkin elimination, return region after elimination.
    w is the Expr object with the real variable to eliminate.
    If w contains random variables, they will be treated as auxiliary RV.

    Quick variant: runs eliminate_quick on a copy, leaving self unchanged.
    """
    return self.copy().eliminate_quick(w, reg, toreal, forall, method = method)
def exists(self, w = None, reg = None, toreal = False, method = ""):
    """Alias of eliminated.

    If w is None, all non-auxiliary random variables are eliminated.
    """
    target = self.allcomprv_noaux() if w is None else w
    return self.copy().eliminate(target, reg, toreal, forall = False, method = method)
def exists_quick(self, w = None, reg = None, toreal = False, method = ""):
    """Alias of eliminated_quick.

    If w is None, all non-auxiliary random variables are eliminated.
    """
    target = self.allcomprv_noaux() if w is None else w
    return self.copy().eliminate_quick(target, reg, toreal, forall = False, method = method)
def forall(self, w = None, reg = None, toreal = False, method = ""):
    """Region of intersection for all variable w.

    If w is None, all non-auxiliary random variables are quantified.
    self is not modified.
    """
    target = self.allcomprv_noaux() if w is None else w
    return self.copy().eliminate(target, reg, toreal, forall = True, method = method)
def forall_quick(self, w = None, reg = None, toreal = False, method = ""):
    """Region of intersection for all variable w.

    Quick variant of forall(). If w is None, all non-auxiliary random
    variables are quantified.
    """
    target = self.allcomprv_noaux() if w is None else w
    return self.copy().eliminate_quick(target, reg, toreal, forall = True, method = method)
def rv_seq_avoid(self, w):
    """Return fresh random variables, one per variable of w, whose names
    avoid all names used in this region and each other.
    """
    index = IVarIndex()
    self.record_to(index)
    fresh = []
    for v in w:
        new_rv = Comp.rv(index.name_avoid(v.get_name()))
        # Record right away so later names also avoid this one.
        new_rv.record_to(index)
        fresh.append(new_rv)
    return sum(fresh, Comp.empty())
def unique(self, w = None):
    """The random variable w satisfying the constraints in this region is unique.
    For both existence and uniqueness, use r.exists_unique(w) instead.

    Built as: for all w, w2, (self(w) and self(w2)) implies w == w2.
    """
    if w is None:
        w = self.allcomprv_noaux()
    w2 = self.rv_seq_avoid(w)
    pairs = list(zip(w, w2))
    both = self & self.substituted(pairs)
    same = alland(equiv(a, b) for a, b in pairs)
    return (both >> same).forall(w + w2)
def exists_unique(self, w = None):
    """The random variable w satisfying the constraints in this region exists and is unique.

    Built as: exists w such that self(w) and, for all w2, self(w2) implies
    w2 == w.
    """
    if w is None:
        w = self.allcomprv_noaux()
    w2 = self.rv_seq_avoid(w)
    pairs = list(zip(w, w2))
    only_one = (self.substituted(pairs) >> alland(equiv(a, b) for a, b in pairs)).forall(w2)
    return (self & only_one).exists(w)
def projected(self, w = None, reg = None, quick = False):
    """Project the region to real variables in the list w by eliminating all
    other real variables. E.g. for a Region r with 3 real variables R1,R2,R3,
    r.projected([R1, R2]) keeps R1,R2 and eliminates R3, and
    r.projected(S == R1+R2+R3) projects the region to the diagonal S == R1+R2+R3
    (i.e., introduces S and eliminates R1,R2,R3).

    quick selects eliminate_quick instead of eliminate.
    """
    if w is None:
        w = []
    if isinstance(w, Expr):
        w = list(w)
    if isinstance(w, Region):
        w = [w]
    to_eliminate = self.allcomprealvar()
    result = self.copy()
    for item in w:
        if isinstance(item, Expr):
            # Keep these variables.
            to_eliminate -= item.allcomp()
            continue
        if not isinstance(item, Region):
            continue
        # Regions are intersected in; their variables are still eliminated
        # unless also listed as Expr.
        constraint = item.flattened(minmax_elim = True) if item.isregtermpresent() else item
        if result.get_type() == RegionType.NORMAL and constraint.get_type() == RegionType.NORMAL:
            result.iand_norename(constraint)
        else:
            result &= constraint
    if quick:
        return result.eliminate_quick(to_eliminate, reg)
    return result.eliminate(to_eliminate, reg)
def splice_rate(self, w0, w1):
    """Allow decreasing the real variable w0 (Expr) and increasing w1 (Expr) by the same amount.

    w0 / w1 may each be a single Expr or a list of Expr. A temporary real
    t >= 0 is introduced: each a0 in w0 is replaced by a0 + t (with the
    shifted a0 required to be >= 0), each a1 in w1 by a1 - t, and t is then
    eliminated.

    Works in place when possible. Returns the resulting region, which may
    be a new object when eliminate() has to fall back to RegionOp, so
    callers should use the return value.
    """
    if isinstance(w0, Expr):
        w0 = [w0]
    if isinstance(w1, Expr):
        w1 = [w1]
    t = Expr.real("#TMPVAR")
    for a0 in w0:
        self.substitute(a0, a0 + t)
    for a1 in w1:
        self.substitute(a1, a1 - t)
    for a0 in w0:
        self &= a0 >= 0
    self &= t >= 0
    # eliminate() returns self when it works in place, and a new RegionOp
    # otherwise; keep whichever it returns.
    return self.eliminate(t)
def spliced_rate(self, w0, w1):
    """Allow decreasing the real variable w0 (Expr) and increasing w1 (Expr) by the same amount.

    Non-mutating version of splice_rate: it runs on a copy and returns what
    splice_rate returns, which may differ from the copy when elimination
    cannot be done in place.
    """
    r = self.copy()
    return r.splice_rate(w0, w1)
def marginal_eliminate(self, w):
    """Set input, in place.
    Denote the RV's in w (type Comp) as input variables, and the region
    is the union over distributions of w.

    Returns self.
    """
    inputs = self.inp
    inputs += w
    self.inp = inputs
    return self
def marginal_exists(self, w):
    """Set input, return result.
    Denote the RV's in w (type Comp) as input variables, and the region
    is the union over distributions of w.

    self is unchanged; the copy is returned.
    """
    result = self.copy()
    result.marginal_eliminate(w)
    return result
def marginal_forall(self, w):
    """Set input, return result.
    Currently we do not differentiate between input distribution union
    and intersection. May change in the future.

    Behaves the same as marginal_exists.
    """
    result = self.copy()
    result.marginal_eliminate(w)
    return result
def kernel_eliminate(self, w):
    """Set output, in place.
    Denote the RV's in w (type Comp) as output variables, and the region
    is the union over channels leading to w with same marginal on w.

    Returns self.
    """
    outputs = self.oup
    outputs += w
    self.oup = outputs
    return self
def kernel_exists(self, w):
    """Set output, return result.
    Denote the RV's in w (type Comp) as output variables, and the region
    is the union over channels leading to w with same marginal on w.

    self is unchanged; the copy is returned.
    """
    result = self.copy()
    result.kernel_eliminate(w)
    return result
def kernel_forall(self, w):
    """Set output, return result.
    Currently we do not differentiate between channel union
    and intersection. May change in the future.

    Behaves the same as kernel_exists.
    """
    result = self.copy()
    result.kernel_eliminate(w)
    return result
def issymmetric(self, w, quick = False):
    """Check whether region is symmetric with respect to the variables in w.

    For each i >= 1, w[0] and w[i] are swapped (through a temporary
    "SYMM_TMP") in a copy. With quick=True the copy must have the same
    sorted string form. Otherwise it must imply and be implied by self.
    """
    reference = self.tostring(tosort = True) if quick else ""
    for i in range(1, len(w)):
        swapped = self.copy()
        placeholder = Comp.rv("SYMM_TMP")
        swapped.substitute(w[i], placeholder)
        swapped.substitute(w[0], w[i])
        swapped.substitute(placeholder, w[0])
        if quick:
            if swapped.tostring(tosort = True) != reference:
                return False
        elif not (swapped.implies(self) and self.implies(swapped)):
            return False
    return True
def __imul__(self, other):
    """Scale the region by other, in place.

    Every real variable v is replaced by v / other, then the region is
    simplified quickly. Returns self.
    """
    for var in self.allcompreal().varlist:
        self.substitute(Expr.fromcomp(Comp([var])),
                        Expr.fromcomp(Comp([var])) / other)
    self.simplify_quick()
    return self
def __mul__(self, other):
    """Return a copy of this region scaled by other (see __imul__)."""
    scaled = self.copy()
    scaled *= other
    return scaled
def __rmul__(self, other):
    """Support other * region; same as region * other."""
    return self * other
def __itruediv__(self, other):
    """Scale the region by 1 / other, in place."""
    factor = 1.0 / other
    self *= factor
    return self
def __truediv__(self, other):
    """Return a copy of this region scaled by 1 / other."""
    return self * (1.0 / other)
def distribute(self, force_split = False):
    """No-op for a plain region: returns self unchanged.

    force_split is accepted for interface compatibility and ignored here.
    """
    result = self
    return result
def sum_minkowski(self, other):
    """Minkowski sum of two regions with respect to their real variables.

    Each real variable R common to both regions is split as
    R = (R - T) + T with T named "TENSOR_TMP_<R>". self constrains R - T,
    other constrains T, and all T are then eliminated (full or quick
    simplification according to PsiOpts.settings["tensorize_simplify"]).
    """
    if other.get_type() != RegionType.NORMAL:
        return other.sum_minkowski(self)
    left = self.copy()
    right = other.copy()
    shared = left.allcomprealvar().inter(right.allcomprealvar())
    offsets = Expr.zero()
    for var in shared.varlist:
        tmp_name = "TENSOR_TMP_" + var.name
        left.substitute(Expr.real(var.name), Expr.real(var.name) - Expr.real(tmp_name))
        right.substitute(Expr.real(var.name), Expr.real(tmp_name))
        offsets += Expr.real(tmp_name)
    if not PsiOpts.settings["tensorize_simplify"]:
        return (left & right).eliminated_quick(offsets)
    return (left & right).eliminated(offsets)
def __add__(self, other):
    """region + other is the Minkowski sum (see sum_minkowski)."""
    summed = self.sum_minkowski(other)
    return summed
def __iadd__(self, other):
    """region += other rebinds to the Minkowski sum (not done in place)."""
    summed = self + other
    return summed
# def __invert__(self):
# if not self.imp_present():
# t = self.one_flipped()
# if t is not None:
# return t.forall(self.aux)
# return ~RegionOp.union([self])
# def negate(self):
# return ~self
def __invert__(self):
    """Complement of the region, built as a negated single-member union."""
    wrapped = RegionOp.union([self])
    return ~wrapped
def try_negate(self, eps_only = False):
    """Try to express the negation of this region directly.

    Returns one_flipped(strict = True).forall(self.aux), or None when the
    region has an implication part, when one_flipped gives None, or, with
    eps_only, when the flipped region contains Expr.eps().
    """
    if self.imp_present():
        return None
    flipped = self.one_flipped(strict = True)
    if flipped is None:
        return None
    if eps_only and flipped.ispresent(Expr.eps()):
        return None
    return flipped.forall(self.aux)
def negate(self):
    """Negation of the region: try_negate() if it succeeds, else ~union."""
    direct = self.try_negate()
    return direct if direct is not None else ~RegionOp.union([self])
def forall_completed(self, exclude = None):
    """Universally quantify every variable in self.rvs that is not auxiliary
    (aux / auxi) and not in exclude.
    """
    free = self.rvs - (self.aux + self.auxi)
    if exclude is not None:
        free -= exclude
    return self.forall(free)
def noeq(self):
    """Rewrite every equality x == 0 as inequalities, in place.

    x >= 0 is added unless x.isnonneg(), and -x >= 0 unless x.isnonpos().
    This applies to exprs_eq -> exprs_ge and to exprs_eqi -> exprs_gei.
    The equalities are then zeroed and removed. Returns self.
    """
    for eq_list, ge_list in ((self.exprs_eq, self.exprs_ge),
                             (self.exprs_eqi, self.exprs_gei)):
        for expr in eq_list:
            if not expr.isnonneg():
                ge_list.append(expr.copy())
            if not expr.isnonpos():
                ge_list.append(-expr)
            expr.setzero()
    self.exprs_eq = [x for x in self.exprs_eq if not x.iszero()]
    self.exprs_eqi = [x for x in self.exprs_eqi if not x.iszero()]
    return self
def copy_noeq(self):
    """Return a copy with equalities rewritten as inequalities (see noeq)."""
    duplicate = self.copy()
    duplicate.noeq()
    return duplicate
def toregionop_split(self, force_split = False):
    """Convert this region into a RegionOp expression.

    With an implication part: a union of imp_flippedonly_noaux() (flag
    False) and the constraint part with aux existentially quantified
    (flag True), all under forall(auxi).
    Otherwise: an intersection, existentially quantified over aux and
    universally over auxi. With force_split, the intersection is over the
    individual constraints after equalities are rewritten (copy_noeq). If
    that leaves exactly one ">=" constraint, a plain copy is returned.
    """
    if not self.imp_present():
        if not force_split:
            return RegionOp.inter([self.copy_noaux()]).exists(self.aux.copy()).forall(self.auxi.copy())
        pieces = self.copy_noaux().copy_noeq()
        if len(pieces.exprs_ge) == 1:
            return self.copy()
        return RegionOp.inter(list(pieces)).exists(self.aux.copy()).forall(self.auxi.copy())
    result = RegionOp.union([])
    result.regs.append((self.imp_flippedonly_noaux(), False))
    constraints = self.consonly().copy_noaux()
    result.regs.append((constraints.exists(self.aux.copy()), True))
    return result.forall(self.auxi.copy())
def copy_rename(self):
    """Return a copy with renamed variables, together with map from old name to new.

    Every random and real variable gets a fresh name (index.name_avoid)
    that avoids the existing names and the names already assigned.
    """
    namemap = {}
    renamed = self.copy()
    index = IVarIndex()
    self.record_to(index)
    # Snapshot both variable sets before new names are recorded.
    groups = ((index.comprv.copy(), Comp.rv), (index.compreal.copy(), Comp.real))
    for comp, make in groups:
        for var in comp.varlist:
            new_name = index.name_avoid(var.name)
            index.record(make(new_name))
            namemap[var.name] = new_name
            renamed.rename_var(var.name, new_name)
    return (renamed, namemap)
# def nfold(self, bnet = None, n = 2, natural_eqprob = True, all_eqprob = True):
# vs = self.allcomprv_noaux()
# vmap = []
# for v in vs:
# vname = v.get_name()
# vmap[vname] = rv_array(vname, n)
def tensorize(self, reg_subset = None, chan_cond = None, nature = None, timeshare = False, hint_aux = None, same_dist = False):
    """Check whether region tensorizes, return auxiliary RVs if tensorizes.
    chan_cond : The condition on the channel (e.g. degraded broadcast channel)

    Returns the first non-signal result of tensorize_gen, or None.
    """
    candidates = self.tensorize_gen(reg_subset = reg_subset, chan_cond = chan_cond,
        nature = nature, timeshare = timeshare, hint_aux = hint_aux, same_dist = same_dist)
    return next((rr for rr in candidates if iutil.signal_type(rr) == ""), None)
def tensorize_gen(self, reg_subset = None, chan_cond = None, nature = None, timeshare = False, hint_aux = None, same_dist = False):
    """Check whether region tensorizes, yield all auxiliary RVs if tensorizes.
    chan_cond : The condition on the channel (e.g. degraded broadcast channel)

    Outline: r2 is a renamed copy of self (second "letter"). rx is
    reg_subset (default: self) with each non-auxiliary RV X replaced by
    X + X' (its renamed twin). rsum combines self and r2. Without
    timeshare, real variables are added up (Minkowski-style). With
    timeshare, sum_entrywise is used. The generator yields whatever
    rx.implies_getaux_gen(rsum, ...) yields.

    reg_subset : Region to start rx from (default: a copy of self).
    nature     : Comp; if non-empty, I(nature; nature') == 0 is added to rx.
    timeshare  : See outline.
    hint_aux   : Forwarded to implies_getaux_gen.
    same_dist  : If True, add eqdist(param_rv, renamed param_rv) to rx.
    """
    r2, namemap = self.copy_rename()
    rx = None
    if reg_subset is None:
        rx = self.copy()
    else:
        rx = reg_subset.copy()
    if chan_cond is None:
        chan_cond = Region.universe()
    chan_cond2 = chan_cond.copy()
    chan_cond2.rename_map(namemap)
    if nature is None:
        nature = Comp.empty()
    nature2 = nature.copy()
    nature2.rename_map(namemap)
    index = IVarIndex()
    self.record_to(index)
    #rx.record_to(index)
    param_rv = index.comprv - self.getauxall()
    param_real = index.compreal
    param_rv_map = Comp.empty()
    param_real_expr = Expr.zero()
    # In rx, each non-auxiliary RV X becomes the pair (X, X').
    for a in param_rv.varlist:
        rx.substitute(Comp.rv(a.name), Comp.rv(a.name) + Comp.rv(namemap[a.name]))
        param_rv_map += Comp.rv(namemap[a.name])
    rsum = None
    if timeshare:
        rsum = self.sum_entrywise(r2)
        for a in param_real.varlist:
            rx.substitute(Expr.real(a.name), Expr.real(namemap[a.name]))
            rsum.substitute(Expr.real(a.name), Expr.zero())
    else:
        # r2 is rewritten in terms of R' - R, so after intersecting with
        # self and eliminating R, rsum describes the sum of the two rates.
        for a in param_real.varlist:
            r2.substitute(Expr.real(namemap[a.name]), Expr.real(namemap[a.name]) - Expr.real(a.name))
            rx.substitute(Expr.real(a.name), Expr.real(namemap[a.name]))
            param_real_expr += Expr.real(a.name)
        rsum = self.copy()
        rsum.iand_norename(r2)
        if PsiOpts.settings["tensorize_simplify"]:
            rsum = rsum.eliminated(param_real_expr)
        else:
            rsum = rsum.eliminated_quick(param_real_expr)
    if rsum.get_type() == RegionType.NORMAL:
        rsum.aux = self.aux.interleaved(r2.aux)
    # Rename the renamed reals back to the original names in rx and rsum.
    for a in param_real.varlist:
        rx.substitute(Expr.real(namemap[a.name]), Expr.real(a.name))
        rsum.substitute(Expr.real(namemap[a.name]), Expr.real(a.name))
    #rx.rename_avoid(chan_cond)
    chan_cond.aux_avoid_from(rx)
    rx &= chan_cond
    chan_cond2.aux_avoid_from(rx)
    rx &= chan_cond2
    # Pair up the channel-condition variables with their renamed twins.
    # NOTE(review): relies on both allcomprv() lists having the same order
    # and length.
    chan_cond_comp = chan_cond.allcomprv()
    chan_cond2_comp = chan_cond2.allcomprv()
    for i in range(chan_cond_comp.size()):
        namemap[chan_cond_comp.varlist[i].name] = chan_cond2_comp.varlist[i].name
    # Conditional independence constraints between the two letters.
    rx.iand_norename(Expr.Ic(self.allcomprv() - self.inp - self.oup - self.aux - self.auxi,
        r2.allcomprv() - r2.oup - r2.aux - r2.auxi,
        self.inp) == 0)
    rx.iand_norename(Expr.Ic(self.inp,
        r2.allcomprv() - r2.inp - r2.oup - r2.aux - r2.auxi,
        r2.inp) == 0)
    if not nature.isempty():
        rx.iand_norename(Expr.I(nature, nature2) == 0)
    # NOTE(review): namemap also contains real-variable names (from
    # copy_rename); they are wrapped as Comp.rv here as well.
    hint_pair = []
    for (key, value) in namemap.items():
        hint_pair.append((Comp.rv(key), Comp.rv(value)))
    if same_dist:
        rx.iand_norename(eqdist(param_rv, param_rv_map))
    for rr in rx.implies_getaux_gen(rsum, hint_pair = hint_pair, hint_aux = hint_aux):
        yield rr
def check_converse(self, reg_subset = None, chan_cond = None, nature = None, hint_aux = None):
    """Check whether self is the capacity region of the operational region.
    reg_subset, return auxiliary RVs if true.
    chan_cond : The condition on the channel (e.g. degraded broadcast channel).

    Same as tensorize(..., timeshare = True).
    """
    return self.tensorize(reg_subset = reg_subset, chan_cond = chan_cond,
                          nature = nature, timeshare = True, hint_aux = hint_aux)
def check_converse_gen(self, reg_subset = None, chan_cond = None, nature = None, hint_aux = None):
    """Check whether self is the capacity region of the operational region.
    reg_subset, yield all auxiliary RVs if true.
    chan_cond : The condition on the channel (e.g. degraded broadcast channel).

    Same as tensorize_gen(..., timeshare = True).
    """
    yield from self.tensorize_gen(reg_subset = reg_subset, chan_cond = chan_cond,
                                  nature = nature, timeshare = True, hint_aux = hint_aux)
def __xor__(self, other):
    """region ^ w is eliminated(w) (returns a new region)."""
    result = self.eliminated(other)
    return result
def __ixor__(self, other):
    """region ^= w rebinds to eliminate(w) (in place when possible)."""
    result = self.eliminate(other)
    return result
def isfeasible(self):
    """Whether this region is feasible.

    The region is infeasible exactly when it implies 1 <= 0.
    """
    implies_contradiction = self.implies(Expr.one() <= 0)
    return not implies_contradiction
def __bool__(self):
    """Truthiness of a region is the result of check()."""
    result = self.check()
    return result
def __call__(self):
    """Calling a region returns the result of check()."""
    result = self.check()
    return result
def truth_value(self):
    """Truth value of this region. Either True (proved to be True), False (proved to be False),
    or None (cannot be proved to be True/False).
    """
    if self:
        return True
    if ~self:
        return False
    return None
def assume(self):
    """Assume this region is true in the current context.

    Adds the region to the existing assumptions (truth_add setting).
    """
    region = self
    PsiOpts.set_setting(truth_add = region)
def assume_only(self):
    """Assume this region is true in the current context. Overwrite existing assumptions.

    A universe region clears the assumptions (truth = None).
    """
    PsiOpts.set_setting(truth = None if self.isuniverse() else self.copy())
def assumed(self):
    """Create a context where this region is assumed to be true.
    Use "with region.assumed(): ..."
    """
    context = PsiOpts(truth_add = self)
    return context
def assumed_only(self):
    """Create a context where this region is assumed to be true, overwriting existing assumptions.
    Use "with region.assumed_only(): ..."
    For a universe region the returned context clears the assumptions (truth = None).
    """
    new_truth = None if self.isuniverse() else self.copy()
    return PsiOpts(truth = new_truth)
def proof(self, mode = None, **kwargs):
    """Get the proof of this region.
    mode : Value for the "proof_new" setting (defaults to True).
    Extra keyword arguments are passed as settings with a "proof_" prefix,
    e.g. foo = 1 becomes proof_foo = 1.
    The region is checked inside that context and the proof recorded there is returned.
    """
    proof_mode = True if mode is None else mode
    prefixed = {"proof_" + name: value for name, value in kwargs.items()}
    result = None
    with PsiOpts(proof_new = proof_mode, **prefixed):
        # Running the check populates the proof that get_proof() returns.
        bool(self)
        result = PsiOpts.get_proof()
    return result
def bound(self, expr, var, sgn = 0, minsize = 1, maxsize = 3, coeffmode = 1, skip_simplify = False):
    """Automatically discover bounds on expr in terms of variables in var.
    Parameters:
        var : Either a list (Expr items are real terms, other items are random
              variables), or an object whose allcomp() components are split into
              real variables and random variables by their type.
        sgn : Set to 1 for upper bound, -1 for lower bound, 0 for both.
        minsize : Minimum number of terms in bound.
        maxsize : Maximum number of terms in bound.
        coeffmode : Set to 0 to only allow positive terms.
                    Set to 1 to allow positive/negative terms, but not all negative.
                    Set to 2 to allow positive/negative terms.
        skip_simplify : Set to True to skip final simplification.
    Returns a Region with the discovered bounds (the universe if none is found).
    """
    if sgn == 0:
        # Both directions: intersect the upper-bound and lower-bound regions.
        shared = dict(minsize = minsize, maxsize = maxsize,
                      coeffmode = coeffmode, skip_simplify = skip_simplify)
        upper = self.bound(expr, var, sgn = 1, **shared)
        lower = self.bound(expr, var, sgn = -1, **shared)
        both = upper & lower
        if not skip_simplify:
            both.simplify_quick()
        return both

    real_terms = []
    rv_terms = []
    if isinstance(var, list):
        for item in var:
            target = real_terms if isinstance(item, Expr) else rv_terms
            target.append(item)
    else:
        for comp in var.allcomp():
            ctype = comp.get_type()
            if ctype == IVarType.REAL:
                real_terms.append(Expr.fromcomp(comp))
            elif ctype == IVarType.RV:
                rv_terms.append(comp)

    # Predicate testing whether a candidate x is a valid bound in the requested direction.
    check = None
    if sgn > 0:
        def check(x):
            return self.implies(expr <= x)
    elif sgn < 0:
        def check(x):
            return self.implies(expr >= x)

    candidates = igen.subset(itertools.chain(igen.sI(rv_terms), real_terms),
                             minsize = minsize, maxsize = maxsize, coeffmode = coeffmode)
    # Keep only the last set yielded by the search.
    found = None
    for _, current in igen.test(candidates, check, sgn = sgn, yield_set = True):
        found = current
    if found is None:
        return Region.universe()

    result = Region.universe()
    for x in found:
        if sgn > 0:
            result.exprs_ge.append(x - expr)
        elif sgn < 0:
            result.exprs_ge.append(expr - x)
    if not skip_simplify:
        result.simplify()
    return result
def union_list(self, distribute = True):
    """Return this region as a list of regions whose union it is: here a single copy.
    distribute is accepted for interface compatibility and is not used.
    """
    single = self.copy()
    return [single]
def order(self, entries):
    """Discover the partial ordering of the entries. Returns a PartialOrder object.
    Entries are added one at a time, in the given order, to a PartialOrder built on this region.
    """
    ordering = PartialOrder(self)
    for entry in entries:
        ordering.add(entry)
    return ordering
class SearchEntry:
    """A search candidate x, with its region form and a coefficient range.

    Attributes:
        x : The candidate object.
        x_reg : Region form of x; defaults to a copy of x.
        cmin, cmax : Coefficient bounds (default -1 and 1).
    """
    def __init__(self, x, x_reg = None, cmin = -1, cmax = 1):
        self.x = x
        # Compare with ``is``: psitip objects overload ``==`` to build
        # constraints (e.g. ``expr == 0`` gives a Region). ``x_reg == None``
        # would therefore build an object whose truthiness may trigger a full
        # check() instead of simply testing for None.
        self.x_reg = x.copy() if x_reg is None else x_reg
        self.cmin = cmin
        self.cmax = cmax
def discover(self, entries, method = "hull_auto", minsize = 1, maxsize = 2, skip_simplify = False, reg_init = None, skipto_ex = None, toreal_prefix = None, balanced = False, init_pts_outer = None, pts_outer = None, entries_final = None, dual_bound_values = "shared_rv", dual_bound_numiter = 1):
"""Automatically discover inequalities between entries.
Parameters:
entries : List of variables of interest.
minsize : Minimum number of terms in bound.
maxsize : Maximum number of terms in bound.
skip_simplify : Set to True to skip final simplification.
"""
if method == "dual":
method = "hull_cone_dual"
if method == "dual_l1":
method = "hull_cone_dual_l1"
ceps = PsiOpts.settings["eps"]
truth = PsiOpts.settings["truth"]
if truth is not None:
with PsiOpts(truth = None):
return (self & truth).discover(entries, method, minsize, maxsize, skip_simplify, reg_init, skipto_ex,
toreal_prefix, balanced, init_pts_outer, pts_outer, entries_final, dual_bound_values, dual_bound_numiter)
indreg = self.get_indreg_checked(reg_add = entries)
if indreg is not None:
with PsiOpts(indreg_enabled = False):
return (self & indreg).discover(entries, method, minsize, maxsize, skip_simplify, reg_init, skipto_ex,
toreal_prefix, balanced, init_pts_outer, pts_outer, entries_final, dual_bound_values, dual_bound_numiter)
verbose = PsiOpts.settings.get("verbose_discover", False)
verbose_detail = PsiOpts.settings.get("verbose_discover_detail", False)
verbose_terms = PsiOpts.settings.get("verbose_discover_terms", False)
verbose_terms_inner = PsiOpts.settings.get("verbose_discover_terms_inner", False)
verbose_terms_outer = PsiOpts.settings.get("verbose_discover_terms_outer", False)
simp_step = 10
varreal = []
varrv = []
varrv_markers = []
maxlen = 1
css = [cs.simplified_quick() for cs in self.union_list()]
# cs = self.simplified_quick()
#plain = cs.isplain()
plain = True
selfifs = None
progs_self = []
progs_r = []
index_self = IVarIndex()
for cs in css:
cs.record_to(index_self)
index_r = IVarIndex()
isaffine = any(cs.affine_present() for cs in css)
if isinstance(entries, dict):
entries = list(entries.items())
if isinstance(entries, Expr):
# entries = entries.allcomp()
entries = [entries]
entries2 = []
for a in entries:
if isinstance(a, tuple) and len(a) and a[0] is None:
continue
if isinstance(a, (CompArray, ExprArray)):
for b in a:
entries2.append(b)
else:
entries2.append(a)
entries = entries2
maxent_comps = []
nonmaxent_present = False
for a in entries:
a2 = a
cmarkers = {}
if balanced:
cmarkers["balanced"] = True
if isinstance(a, tuple):
a2 = []
for a0 in a:
if isinstance(a0, str):
cmarkers[a0] = True
else:
a2.append(a0)
if len(a2) == 1:
a2 = a2[0]
else:
a2 = tuple(a2)
if not isinstance(a2, tuple):
if isinstance(a2, Expr) and toreal_prefix is not None:
a2 = (a2, Expr.real(toreal_prefix + str(a2)))
else:
a2 = (a2, None)
if isinstance(a2[1], ExprArray) or isinstance(a2[1], CompArray):
maxlen = max(maxlen, len(a2[1]))
a2[0].record_to(index_r)
aself = None
if a2[1] is None:
aself = a2[0]
else:
aself = a2[1]
aself.record_to(index_self)
if isinstance(aself, Expr) and aself.affine_present():
isaffine = True
plain = plain and not aself.isregtermpresent()
if isinstance(aself, Expr):
for t in aself:
tt = t.get_maxent_comp()
if tt is not None:
if tt not in maxent_comps:
maxent_comps.append(tt)
else:
nonmaxent_present = True
else:
nonmaxent_present = True
if isinstance(a2[0], Expr):
varreal.append(a2[0])
else:
varrv.append(a2[0])
varrv_markers.append(cmarkers)
if maxent_comps and not nonmaxent_present:
if PsiOpts.settings["maxent_lex_enabled"]:
with PsiOpts(maxent_lex_enabled = False):
r = Region.universe()
for cs in self.maxent_lex_region_gen([(x, False) for x in maxent_comps], varadd = index_self.comprv):
# print(cs)
t = cs.discover(entries, method, minsize, maxsize, skip_simplify, reg_init, skipto_ex,
toreal_prefix, balanced, init_pts_outer, pts_outer)
r &= t
return r.simplified_quick()
for cs in css:
if cs.get_type() == RegionType.NORMAL and not cs.aux.isempty() and not cs.imp_present():
cs.aux_strengthen(index_self.comprv)
cs.aux = Comp.empty()
plain = plain and all(cs.isplain() for cs in css)
r = Region.universe()
if reg_init is not None:
r = reg_init.copy()
#print(skipto_ex)
if skipto_ex is not None:
if not isinstance(skipto_ex, str):
skipto_ex = skipto_ex.tostring()
nadd = 0
if plain:
selfifs = [cs.imp_flipped() for cs in css]
if method == "hull_auto":
if isaffine:
method = "hull"
else:
method = "hull_cone"
if (method == "hull" or method == "hull_cone" or method == "hull_cone_dual" or method == "hull_cone_dual_l1") and not plain:
method = "guess"
#print(index_r.comprv)
#print(index_self.comprv)
if method == "semigraphoid" or method == "ic" or method == "ci":
method = "semigraphoid"
perform_semigraphoid = method == "semigraphoid" or method == "hull" or method == "hull_cone" or method == "hull_cone_dual" or method == "hull_cone_dual_l1"
if len(css) > 1:
perform_semigraphoid = False
sg = None
if perform_semigraphoid:
tvs = []
for a in entries:
if isinstance(a, tuple) and isinstance(a[1], Comp):
tvs.append((a[0], a[1]))
elif isinstance(a, Comp):
tvs.append(a)
sg = css[0].completed_semigraphoid(vs = tvs)
if method == "semigraphoid":
return sg
# sg = None
vis = set()
terms = []
if method == "hull" or method == "hull_cone" or method == "hull_cone_dual" or method == "hull_cone_dual_l1":
if sg is None or sg.isuniverse():
terms = list(itertools.chain(varreal, ent_vector(*varrv)))
else:
terms = list(itertools.chain(varreal, sg.get_basis(more_vars = sum(varrv, Comp.empty()))))
else:
terms = list(itertools.chain(varreal, igen.sI(varrv)))
for a, markers in zip(varrv, varrv_markers):
if markers.get("balanced", False):
for t in terms:
c = t.get_coeff(a)
if abs(c) > ceps:
t += Expr.H(a) * -c
t.simplify_quick()
terms = [t for t in terms if not t.iszero()]
if verbose:
print("discover nterms = " + str(len(terms)))
if verbose_terms:
for t in terms:
print(t)
print()
lastres = [False]
def expr_tr(ex):
exsum = Expr.zero()
for i in range(maxlen):
tex = ex.copy()
for a in entries:
if isinstance(a, tuple):
if isinstance(a[1], ExprArray) or isinstance(a[1], CompArray):
if i < len(a[1]):
tex.substitute(a[0], a[1][i])
else:
if isinstance(a[0], Expr):
tex.substitute(a[0], Expr.zero())
else:
tex.substitute(a[0], Comp.empty())
else:
tex.substitute(a[0], a[1])
exsum += tex
return exsum
if entries_final is not None:
entries_final += [(a, expr_tr(a)) for a in terms]
#print(method)
if method == "hull_cone_dual" or method == "hull_cone_dual_l1":
progs = []
dual_discover_l1 = (method == "hull_cone_dual_l1")
for cs in css:
cs2 = cs.copy()
# for i, a in enumerate(terms):
# cs2.iand_norename((a == 0).add_meta("dual_discover", i))
P = dual_bound_values
if isinstance(dual_bound_values, str):
if dual_bound_values == "example_bsc":
P = cs2.example_bsc(addrv = index_self.comprv)
elif dual_bound_values == "shared_rv":
P = SharedRVModel(cs2, addrv = index_self.comprv)
cs2.add_real_to(P)
with PsiOpts(lp_dual_form = True, lp_dual_form_discover = True):
prog = cs2.imp_flipped().init_prog(index = index_self, lp_bounded = False,
discover_exprs = [expr_tr(a) for a in terms], discover_init_model = P, discover_method = "l1" if dual_discover_l1 else "")
progs.append(prog)
m = len(terms)
def toexpr(x):
r = Expr.zero()
for i in range(m):
if abs(x[i]) > ceps:
r += terms[i] * x[i]
return r
def cprog(x):
ropt = None
rr = None
for prog in progs:
# print(x)
# print(prog.vdiscover)
n = prog.dual_form_ncons
c = [0.0] * n
if dual_discover_l1:
for i in range(m):
c[prog.vdiscover[i * 2]] += x[i]
c[prog.vdiscover[i * 2 + 1]] -= x[i]
else:
for i in range(m):
c[prog.vdiscover[i]] += x[i]
ret = (None, None)
disabled = set()
prog.vdiscover_disabled = set()
disable_candidates = {None}
for it in range(dual_bound_numiter):
for cd in disable_candidates:
if cd is not None:
prog.vdiscover_disabled = disabled.union({cd})
else:
prog.vdiscover_disabled = set(disabled)
prog.apply_discover_bounds()
# print(prog.solver_param["model"].ExportModelAsLpFormat(False).replace('\\', '').replace(',_', ','))
opt, v = prog.call_prog(c)
if opt is None:
continue
ret = (opt, v)
disabled.add(cd)
disable_candidates = prog.active_discover_bounds(v)
if verbose_detail:
def tostr(i):
if i == -1:
return "SUM"
else:
return str(terms[i])
if cd is not None:
print("DISABLE: " + tostr(cd))
if disable_candidates:
print("DISABLE CANDIDATES: " + ", ".join([tostr(i) for i in disable_candidates]))
break
else:
break
if len(disable_candidates) <= 1:
break
opt, v = ret
# print(c)
# opt, v = prog.call_prog(c)
# print(v)
# print(prog.dual_form_ncons)
# print(prog.dual_form)
# print()
if opt is None:
return (None, None)
r = [0.0] * m
if dual_discover_l1:
for i in range(m):
r[i] = v[prog.vdiscover[i * 2]] - v[prog.vdiscover[i * 2 + 1]]
else:
for i in range(m):
r[i] = v[prog.vdiscover[i]]
return (opt, r)
rt = LinearProg.proj_hull(cprog, m, toexpr = toexpr, iscone = True, is_dual_discover = True)
if rt is None:
return None
inf_thres = progs[0].lp_ubound / 5.0
for x in rt:
if abs(x[0]) > inf_thres:
continue
expr = Expr.zero()
for i in range(len(terms)):
if abs(x[i + 1]) > ceps:
expr += terms[i] * x[i + 1]
r.iand_norename(expr >= -x[0])
if not (sg is None or sg.isuniverse()):
r &= sg
if not skip_simplify:
r.simplify()
return r
elif method == "hull" or method == "hull_cone":
# prog = cs.imp_flipped().init_prog(index = index_self, lp_bounded = True)
# A = SparseMat(0)
# for a in terms:
# A.extend(prog.get_vec(expr_tr(a), sparse = True)[0])
# rt = prog.discover_hull(A, iscone = (method == "hull_cone"), init_pts_outer = init_pts_outer, pts_outer = pts_outer)
progs = [cs.imp_flipped().init_prog(index = index_self, lp_bounded = True) for cs in css]
Ams = []
for prog in progs:
A = SparseMat(0)
for a in terms:
A.extend(prog.get_vec(expr_tr(a), sparse = True)[0])
Ams.append(A)
m = len(terms)
def toexpr(x):
r = Expr.zero()
for i in range(m):
if abs(x[i]) > ceps:
r += terms[i] * x[i]
return r
def cprog(x):
ropt = None
rr = None
for A, prog in zip(Ams, progs):
n = prog.nxvar
c = [0.0] * n
for i in range(m):
for j, a in A.x[i]:
c[j] += a * x[i]
opt, v = prog.call_prog(c)
if opt is None:
return (None, None)
r = [0.0] * m
for i in range(m):
for j, a in A.x[i]:
r[i] += a * v[j]
# return (opt, r)
if ropt is None or opt < ropt:
ropt = opt
rr = r
return (ropt, rr)
rt = LinearProg.proj_hull(cprog, m, toexpr = toexpr, iscone = (method == "hull_cone"), init_pts_outer = init_pts_outer, pts_outer = pts_outer)
if rt is None:
return None
inf_thres = progs[0].lp_ubound / 5.0
for x in rt:
if abs(x[0]) > inf_thres:
continue
expr = Expr.zero()
for i in range(len(terms)):
if abs(x[i + 1]) > ceps:
expr += terms[i] * x[i + 1]
r.iand_norename(expr >= -x[0])
if not (sg is None or sg.isuniverse()):
r &= sg
if not skip_simplify:
r.simplify()
return r
def exgen():
for size in range(minsize, maxsize + 1):
if size == 1:
for i in range(len(terms)):
yield terms[i]
res0 = lastres[0]
yield -terms[i]
res1 = lastres[0]
if res0 and res1:
terms[i] = None
terms[:] = [a for a in terms if a is not None]
elif size == 2:
for i in range(len(terms)):
for j in range(i):
if terms[j] is not None:
yield terms[i] - terms[j]
res0 = lastres[0]
yield terms[j] - terms[i]
res1 = lastres[0]
if res0 and res1:
terms[i] = None
break
terms[:] = [a for a in terms if a is not None]
else:
for ex in igen.subset(terms, minsize = size,
maxsize = maxsize, coeffmode = -1, replacement = True):
yield ex
break
for ex in exgen():
if PsiOpts.is_timer_ended():
break
if skipto_ex is not None:
if skipto_ex != ex.tostring():
lastres[0] = False
continue
skipto_ex = None
if verbose_terms:
print(str(ex) + " <= 0")
ex = ex.simplified_quick()
if ex.isnonpos():
lastres[0] = True
continue
exhash = hash(ex)
if exhash in vis:
lastres[0] = False
continue
vis.add(exhash)
r2 = ex <= 0
if r.implies_saved(r2, index_r, progs_r):
lastres[0] = True
continue
exsum = expr_tr(ex)
if ((plain and all(selfif.implies_impflipped_saved(exsum <= 0, index_self, progs_self) for selfif in selfifs))
or (not plain and all(cs.implies(exsum <= 0) for cs in css))):
if verbose:
print("ADD " + str(ex) + " <= 0")
r &= r2
nadd += 1
if not skip_simplify and nadd % simp_step == 0:
r.simplify()
progs_r = []
#print("OKAY " + str(r2))
if verbose_detail:
print(str(r))
lastres[0] = True
else:
#print("FAIL " + str(r2) + " " + str(plain) + " " + str(saved_info))
lastres[0] = False
if not skip_simplify:
r.simplify()
return r
def hull(self, entries = None, **kwargs):
    """Find the convex hull outer bound of this region.
    The random variables of this region, then its real variables, then any
    extra entries are passed to discover() together with kwargs.
    """
    targets = list(self.rvs)
    targets.extend(self.reals)
    if entries is not None:
        targets.extend(entries)
    return self.discover(targets, **kwargs)
def extreme_rays(self, *args, **kwargs):
    """Find the extreme rays of a list of real variables.
    Runs discover(*args, **kwargs) while collecting its outer points through
    pts_outer. Returns each point whose coordinates are not all (near) zero,
    scaled so that its largest absolute coordinate is 1.
    """
    tolerance = PsiOpts.settings["eps"]
    outer_points = []
    self.discover(*args, pts_outer = outer_points, **kwargs)
    rays = []
    for point in outer_points:
        # Ignore points that are numerically the origin.
        if sum(abs(coord) for coord in point) <= tolerance:
            continue
        scale = max(abs(coord) for coord in point)
        rays.append([coord / scale for coord in point])
    return rays
def dual_cone(self, entries = None, partition = None, sgn = 1, skip_simplify = False, *args, **kwargs):
    """Find the dual cone of a list of real variables.
    Parameters:
        entries : Variables of interest (default: all real variables of this region).
        partition : Optional list of groups of entries. One constraint is produced
            per extreme ray and per group, using only the terms whose string form
            belongs to that group (a None group means "all terms").
        sgn : Each constraint has the form expr * sgn >= 0 (polar_cone uses -sgn).
        skip_simplify : Set to True to skip simplification.
        *args, **kwargs : Forwarded to extreme_rays, and hence to discover
            (e.g. method = "hull_cone"). Previously they were accepted but
            silently ignored. Extra positional arguments map onto discover's
            method, minsize, maxsize, ... parameters.
    """
    ceps = PsiOpts.settings["eps"]
    if entries is None:
        entries = list(self.reals)
    # discover() fills entries_final with (term, translated term) pairs that
    # line up with the coordinates of each extreme ray.
    entries_final = []
    rays = self.extreme_rays(entries, *args, skip_simplify = skip_simplify,
                             entries_final = entries_final, **kwargs)
    r = Region.universe()
    if partition is None:
        partition = [None]
    # Groups are compared by string form.
    partition = [None if y is None else set(str(x) for x in y) for y in partition]
    for ray in rays:
        for part in partition:
            expr = Expr.zero()
            for x, (a, b) in zip(ray, entries_final):
                if part is not None and str(a) not in part:
                    continue
                if abs(x) > ceps:
                    expr += a * x
            r &= expr * sgn >= 0
    if not skip_simplify:
        r.simplify()
    return r
def polar_cone(self, entries = None, partition = None, sgn = 1, skip_simplify = False, *args, **kwargs):
    """Find the polar cone of a list of real variables.
    This is dual_cone with the sign flipped; all other arguments are passed through.
    """
    flipped_sign = -sgn
    return self.dual_cone(entries, partition, flipped_sign, skip_simplify, *args, **kwargs)
@staticmethod
def markov_tostring(cm, style = 0, tosort = False):
    """Render a Markov chain cm = [A0, B0, A1, B1, ...] as a string.
    If cm has odd length and every odd-indexed element is empty, it is rendered
    as a mutual independence of the even-indexed elements (indep(...) or the
    LaTeX latex_indep separator). Otherwise it is rendered as markov(...) or
    joined by the LaTeX latex_markov separator.
    """
    bracket = not (style & PsiOpts.STR_STYLE_PSITIP)
    latex = style & PsiOpts.STR_STYLE_LATEX
    count = len(cm)
    is_indep = count % 2 == 1 and all(cm[k].isempty() for k in range(1, count, 2))
    indices = range(0, count, 2) if is_indep else range(count)
    parts = [cm[k].tostring(style = style, tosort = tosort, add_bracket = bracket)
             for k in indices]
    if is_indep:
        if latex:
            return (" " + PsiOpts.settings["latex_indep"] + " ").join(parts)
        return "indep(" + ", ".join(parts) + ")"
    if latex:
        return (" " + PsiOpts.settings["latex_markov"] + " ").join(parts)
    return "markov(" + ", ".join(parts) + ")"
def tostring(self, style = 0, tosort = False, lhsvar = None, inden = 0, add_bracket = False,
small = False, skip_outer_exists = False):
"""Convert to string.
Parameters:
style : Style of string conversion
STR_STYLE_STANDARD : I(X,Y;Z|W)
STR_STYLE_PSITIP : I(X+Y&Z|W)
The STR_STYLE_LATEX, STR_STYLE_LATEX_ARRAY, STR_STYLE_LATEX_QUANTAFTER
and STR_STYLE_MARKOV flags are also honoured.
tosort : If True, sort the listed constraints: those mentioning a
variable in lhsvar first, then shorter strings first.
lhsvar : Variables preferred on the left-hand side. None means
PsiOpts.settings["lhsvar"]; the string "real" means all real
variables of this region.
inden : Number of indentation units (a space, or a LaTeX thin space)
prefixed to the output.
add_bracket : If True, enclose the output in brackets.
small : In LaTeX output, do not insert the displaystyle command.
skip_outer_exists : If True, do not print the leading LaTeX "exists"
symbol. A universe region with only existential auxiliaries is then
printed as just its auxiliary list.
"""
if lhsvar is None:
lhsvar = PsiOpts.settings["lhsvar"]
style = iutil.convert_str_style(style)
r = ""
# Proof notes: when a "pf_note_pre" meta entry is attached, print each note
# followed by the region itself (rendered without that meta entry).
pf_note = PsiOpts.settings["str_proof_note"]
if pf_note:
note_pre = self.get_meta("pf_note_pre")
if note_pre is not None:
cs = self.copy()
cs.remove_meta("pf_note_pre")
r += cs.tostring(style = style, tosort = tosort,
lhsvar = lhsvar, inden = inden,
add_bracket = add_bracket, small = small,
skip_outer_exists = skip_outer_exists)
r = [iutil.pf_note_str(x, style, add_bracket = False) for x in note_pre] + [r]
if style & PsiOpts.STR_STYLE_LATEX:
r = iutil.latex_split_line(r, None)
else:
if isinstance(r, list):
r = "\n".join(r)
return r
if isinstance(lhsvar, str) and lhsvar == "real":
lhsvar = self.allcomprealvar()
# Line separator and indentation unit depend on the output style.
nlstr = "\n"
if style & PsiOpts.STR_STYLE_LATEX:
nlstr = "\\\\\n"
spacestr = " "
if style & PsiOpts.STR_STYLE_LATEX:
spacestr = "\\;"
# Shortcut: a universe region with only existential auxiliaries prints just the auxiliaries.
if skip_outer_exists and not self.aux.isempty() and self.auxi.isempty() and self.isuniverse():
return spacestr * inden + self.aux.tostring(style = style, tosort = tosort)
# Canonical empty region.
if self.isuniverse(sgn = False, canon = True):
if style & PsiOpts.STR_STYLE_PSITIP:
return spacestr * inden + "empty()"
elif style & PsiOpts.STR_STYLE_LATEX:
return spacestr * inden + PsiOpts.settings["latex_region_empty"]
imp_pres = False
# Universally quantified variables (auxi) and/or implication premises
# (exprs_gei / exprs_eqi): render "premises => conclusion", quantified
# by "forall auxi", and return early.
if not self.auxi.isempty():
if style & PsiOpts.STR_STYLE_PSITIP:
pass
else:
if style & PsiOpts.STR_STYLE_LATEX:
if not style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER:
r += PsiOpts.settings["latex_forall"] + " "
r += self.auxi.tostring(style = style, tosort = tosort)
r += PsiOpts.settings["latex_quantifier_sep"] + " "
if self.exprs_gei or self.exprs_eqi or not self.auxi.isempty():
add_bracket_inner = False
inden_inner2 = inden
curadd_bracket = bool(add_bracket or (not self.auxi.isempty() and (self.exprs_gei or self.exprs_eqi)))
if style & PsiOpts.STR_STYLE_LATEX:
if curadd_bracket:
r += "\\left\\{"
else:
r += "("
if self.exprs_gei or self.exprs_eqi:
imp_pres = True
inden_inner = inden
inden_inner1 = inden + 1
inden_inner2 = inden + 3
add_bracket_inner = bool(style & PsiOpts.STR_STYLE_PSITIP)
r += spacestr * inden
if style & PsiOpts.STR_STYLE_LATEX:
r += "\\begin{array}{l}\n"
if not small:
r += r"\displaystyle"
r += " "
inden_inner = 0
inden_inner1 = 0
inden_inner2 = 0
else:
pass
# Premises: the implication part taken on its own, without auxiliaries.
r += self.imp_flippedonly_noaux().tostring(style = style, tosort = tosort,
lhsvar = lhsvar, inden = inden_inner1,
add_bracket = add_bracket_inner).lstrip()
if style & PsiOpts.STR_STYLE_PSITIP:
r += nlstr + spacestr * inden_inner + ">> "
elif style & PsiOpts.STR_STYLE_LATEX:
r += (nlstr + (r"\displaystyle " if not small else " ") + spacestr * inden_inner
+ PsiOpts.settings["latex_matimplies"] + " " + spacestr + " ")
else:
r += nlstr + spacestr * inden_inner + "=> "
# Conclusion: this region with the premises and universal variables removed.
cs = self.copy()
cs.exprs_gei = []
cs.exprs_eqi = []
cs.auxi = Comp.empty()
r += cs.tostring(style = style, tosort = tosort,
lhsvar = lhsvar, inden = inden_inner2,
add_bracket = add_bracket_inner).lstrip()
if not self.auxi.isempty():
if style & PsiOpts.STR_STYLE_PSITIP:
pass
else:
if style & PsiOpts.STR_STYLE_LATEX:
if style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER:
r += " , " + PsiOpts.settings["latex_forall"] + " "
r += self.auxi.tostring(style = style, tosort = tosort)
else:
r += " , forall "
r += self.auxi.tostring(style = style, tosort = tosort)
if style & PsiOpts.STR_STYLE_LATEX:
if self.exprs_gei or self.exprs_eqi:
r += nlstr + "\\end{array}"
if curadd_bracket:
r += "\\right\\}"
else:
r += ")"
if not self.auxi.isempty():
if style & PsiOpts.STR_STYLE_PSITIP:
r += ".forall(" + self.auxi.tostring(style = style, tosort = tosort) + ")"
return r
# Plain region (no premises, no universal variables). Print the leading
# "exists" quantifier in LaTeX unless it goes after the constraints.
if not self.aux.isempty():
if style & PsiOpts.STR_STYLE_PSITIP:
pass
else:
if style & PsiOpts.STR_STYLE_LATEX:
if not style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER:
if not skip_outer_exists:
r += PsiOpts.settings["latex_exists"] + " "
r += self.aux.tostring(style = style, tosort = tosort)
r += PsiOpts.settings["latex_quantifier_sep"] + " "
cs = self
bnets = None
# Markov style: pull the conditional-independence constraints out (on a
# copy) so they can be shown as Markov chains / independences instead.
if style & PsiOpts.STR_STYLE_MARKOV:
cs = cs.copy()
tic = cs.remove_ic()
bnets = BayesNet.from_ic_list(tic)
eqnlist = ([x.tostring_eqn(">=", style = style, tosort = tosort, lhsvar = lhsvar) for x in cs.exprs_ge]
+ [x.tostring_eqn("==", style = style, tosort = tosort, lhsvar = lhsvar) for x in cs.exprs_eq])
if tosort:
# Constraints mentioning lhsvar first, then by length, then lexicographically.
eqnlist = zip(eqnlist, [lhsvar is not None and any(x.ispresent(t) for t in lhsvar) for x in cs.exprs_ge]
+ [lhsvar is not None and any(x.ispresent(t) for t in lhsvar) for x in cs.exprs_eq])
eqnlist = sorted(eqnlist, key=lambda a: (not a[1], len(a[0]), a[0]))
eqnlist = [x for x, t in eqnlist]
if style & PsiOpts.STR_STYLE_MARKOV:
eqnlist2 = []
for bnet in bnets:
ms = bnet.get_markov()
for cm in ms:
eqnlist2.append(Region.markov_tostring(cm, style, tosort))
# if len(cm) % 2 == 1 and all(cm[i].isempty() for i in range(1, len(cm), 2)):
# tlist = [cm[i].tostring(style = style, tosort = tosort,
# add_bracket = not (style & PsiOpts.STR_STYLE_PSITIP)) for i in range(0, len(cm), 2)]
# if style & PsiOpts.STR_STYLE_LATEX:
# eqnlist2.append((" " + PsiOpts.settings["latex_indep"] + " ").join(tlist))
# else:
# eqnlist2.append("indep(" + ", ".join(tlist) + ")")
# else:
# tlist = [cm[i].tostring(style = style, tosort = tosort,
# add_bracket = not (style & PsiOpts.STR_STYLE_PSITIP)) for i in range(len(cm))]
# if style & PsiOpts.STR_STYLE_LATEX:
# eqnlist2.append((" " + PsiOpts.settings["latex_markov"] + " ").join(tlist))
# else:
# eqnlist2.append("markov(" + ", ".join(tlist) + ")")
if tosort:
eqnlist2 = sorted(eqnlist2, key = lambda a: len(a))
eqnlist += eqnlist2
first = True
# isplu: whether there are several constraints (controls separators and brackets).
use_array = style & PsiOpts.STR_STYLE_LATEX_ARRAY and len(eqnlist) > 1
isplu = True
if style & PsiOpts.STR_STYLE_LATEX:
isplu = len(eqnlist) > 1 and use_array
else:
isplu = not self.aux.isempty() or not self.auxi.isempty() or len(eqnlist) > 1
use_bracket = add_bracket or isplu
# Opening bracket.
if style & PsiOpts.STR_STYLE_PSITIP:
r += spacestr * inden
if use_bracket:
r += "("
elif style & PsiOpts.STR_STYLE_LATEX:
r += spacestr * inden
if use_bracket:
r += "\\left\\{"
if use_array:
r += "\\begin{array}{l}\n"
else:
r += spacestr * inden
if use_bracket:
r += PsiOpts.settings["str_brace_l"]
# Emit the constraints with style-specific separators ("&" in PSITIP style,
# "," in LaTeX, str_ineq_sep otherwise).
for x in eqnlist:
if style & PsiOpts.STR_STYLE_PSITIP:
if first:
if use_bracket:
r += " "
else:
r += nlstr + spacestr * inden + " &"
if isplu:
r += "("
if use_bracket:
r += " "
elif style & PsiOpts.STR_STYLE_LATEX:
if use_array:
if first:
r += spacestr * inden + " "
else:
r += "," + nlstr + spacestr * inden + " "
if small:
# r += "{\\scriptsize "
pass
else:
if first:
r += " "
else:
r += "," + spacestr + " "
else:
if first:
if use_bracket:
r += " "
else:
r += PsiOpts.settings["str_ineq_sep"] + nlstr + spacestr * inden + " "
r += x
if style & PsiOpts.STR_STYLE_PSITIP:
r += " "
if isplu:
r += ")"
elif style & PsiOpts.STR_STYLE_LATEX:
if use_array:
if small:
# r += "}"
pass
first = False
# No constraints: print the universe symbol.
if len(eqnlist) == 0:
if r != "":
r += " "
if style & PsiOpts.STR_STYLE_PSITIP:
r += "universe()"
elif style & PsiOpts.STR_STYLE_LATEX:
r += PsiOpts.settings["latex_region_universe"]
# Closing bracket.
if style & PsiOpts.STR_STYLE_PSITIP:
if use_bracket:
r += " )"
elif style & PsiOpts.STR_STYLE_LATEX:
if use_array:
r += nlstr + spacestr * inden + "\\end{array}"
if use_bracket:
r += " \\right\\}"
else:
if use_bracket:
r += " " + PsiOpts.settings["str_brace_r"]
# Trailing "exists" quantifier (standard style, or LaTeX with QUANTAFTER).
if not self.aux.isempty():
if style & PsiOpts.STR_STYLE_PSITIP:
pass
else:
if style & PsiOpts.STR_STYLE_LATEX:
if style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER:
r += " , " + PsiOpts.settings["latex_exists"] + " "
r += self.aux.tostring(style = style, tosort = tosort)
else:
r += " , exists "
r += self.aux.tostring(style = style, tosort = tosort)
if imp_pres:
r += ")"
# PSITIP style expresses the existential quantifier as a method call.
if not self.aux.isempty():
if style & PsiOpts.STR_STYLE_PSITIP:
r += ".exists(" + self.aux.tostring(style = style, tosort = tosort) + ")"
return r
def __str__(self):
    """Render using the global "str_style" and "str_tosort" settings, with no lhsvar."""
    chosen_style = PsiOpts.settings["str_style"]
    return self.tostring(chosen_style,
                         tosort = PsiOpts.settings["str_tosort"], lhsvar = None)
def tostring_repr(self, style):
    """String used by __repr__ and _latex_.
    If the "repr_check" setting is on and the region can be proved, returns "True".
    If "repr_simplify" is on, the region is quickly simplified before conversion.
    """
    if PsiOpts.settings.get("repr_check", False) and self.check():
        return str(True)
    shown = self
    if PsiOpts.settings.get("repr_simplify", False):
        shown = self.simplified_quick()
    return shown.tostring(style,
                          tosort = PsiOpts.settings["str_tosort"], lhsvar = None)
def __repr__(self):
    """Render with the "str_style_repr" setting."""
    repr_style = PsiOpts.settings["str_style_repr"]
    return self.tostring_repr(repr_style)
@latex_postprocess
def _latex_(self):
    """LaTeX rendering (post-processed by latex_postprocess)."""
    latex_style = iutil.convert_str_style("latex")
    return self.tostring_repr(latex_style)
def __hash__(self):
    """Hash that ignores the order of the constraints within each list."""
    def unordered_hash(items):
        return hash(frozenset(hash(item) for item in items))
    parts = (
        unordered_hash(self.exprs_ge),
        unordered_hash(self.exprs_eq),
        unordered_hash(self.exprs_gei),
        unordered_hash(self.exprs_eqi),
        hash(self.aux), hash(self.inp), hash(self.oup), hash(self.auxi),
    )
    return hash(parts)
def ent_vector_discover_ic(v, x, eps = None, skip_simplify = False):
    """Discover equality constraints satisfied by a vector v over the variables in x.
    v is indexed by bitmask over x: v[mask] is treated as the entropy of
    x.from_mask(mask). If v has 2**n - 1 entries, the empty-set value 0.0 is
    prepended. Values are compared with tolerance eps (default: the
    "eps_check" setting).
    Returns a Region with
      - Hc(x_i | x_S) == 0 whenever adding x_i to S raises v by at most eps, and
      - Ic(x_A; x_B | x_C) == 0 whenever v[A|B|C] - v[B|C] and v[A|C] - v[C]
        fall into the same or adjacent buckets of width eps.
    Unless skip_simplify is set, the result is simplified.
    """
    if eps is None:
        eps = PsiOpts.settings["eps_check"]
    n = len(x)
    full = (1 << n) - 1
    if len(v) == full:
        v = [0.0] + list(v)
    r = Region.universe()

    # Functional dependencies.
    for mask in range(1 << n):
        for i in range(n):
            bit = 1 << i
            if mask & bit:
                continue
            if v[mask | bit] - v[mask] <= eps:
                r.exprs_eq.append(Expr.Hc(x[i], x.from_mask(mask)))

    # Conditional independences: bucket the conditional values of mask given
    # each disjoint xmask, and match against earlier subsets ymask of xmask in
    # the same or neighbouring bucket.
    for mask in range(1, 1 << n):
        buckets = {}
        for xmask in igen.subset_mask(full - mask):
            gain = v[mask | xmask] - v[xmask]
            if gain <= eps:
                continue
            center = int(gain / eps + 0.5)
            for b in range(max(center - 1, 0), center + 2):
                for ymask in buckets.get(b, ()):
                    # ymask must be contained in xmask; the ordering test keeps
                    # one of the two symmetric forms of the same independence.
                    if ymask & xmask == ymask and xmask - ymask < mask:
                        r.exprs_eq.append(Expr.Ic(x.from_mask(mask),
                                                  x.from_mask(xmask - ymask),
                                                  x.from_mask(ymask)))
            buckets.setdefault(center, []).append(xmask)

    if not skip_simplify:
        r.simplify_bayesnet(reduce_ic = True)
        r.simplify()
    return r
class RegionOp(Region):
"""A region which is the union/intersection of a list of regions."""
def __init__(self, rtype, regs, auxs, inp = None, oup = None, meta = None):
    """Create a union/intersection of regions.
    rtype : Region type tag (e.g. RegionType.UNION or RegionType.INTER).
    regs : List of (region, flag) pairs; a False flag is treated as the
        complement of that member.
    auxs : List of (comp, flag) pairs; flag True marks an existential
        auxiliary, False a universal one.
    inp, oup : Input / output components (default empty).
    meta : Metadata object (stored as is).
    """
    self.rtype, self.regs, self.auxs = rtype, regs, auxs
    self.inp = inp if inp is not None else Comp.empty()
    self.oup = oup if oup is not None else Comp.empty()
    self.meta = meta
def get_type(self):
    """Return the stored region type tag."""
    region_type = self.rtype
    return region_type
def isnormalcons(self):
    """A union/intersection is never a normal constraint region."""
    result = False
    return result
def isuniverse(self, sgn = True, canon = False):
    """Whether this region is the universe (sgn True) or the empty region (sgn False).
    canon : If True, only recognize the canonical forms: an intersection with no
        members and no auxiliaries (universe), or such a union (empty).
    Otherwise each member is inspected, testing the opposite property for
    members whose flag is False.
    """
    if canon:
        wanted = RegionType.INTER if sgn else RegionType.UNION
        return self.get_type() == wanted and len(self.regs) == 0 and len(self.auxs) == 0
    # For sgn True, a union succeeds as soon as one member is the universe, and
    # an intersection fails as soon as one member is not; dually for sgn False.
    fallback = (self.get_type() == RegionType.UNION) ^ sgn
    for member, flag in self.regs:
        if fallback ^ member.isuniverse(not flag ^ sgn):
            return not fallback
    return fallback
def isempty(self):
    """Whether this region is (recognizably) empty."""
    return self.isuniverse(sgn = False)
def iseq(self):
    """Is this pure equality. Always False for a union/intersection."""
    pure_equality = False
    return pure_equality
def isineq(self):
    """Is this pure inequality. Always False for a union/intersection."""
    pure_inequality = False
    return pure_inequality
def expr(self):
    """Returns the sum of the expressions of the member regions (flags ignored)."""
    member_exprs = (member.expr() for member, _ in self.regs)
    return sum(member_exprs, Expr.zero())
def isplain(self):
    """A union/intersection is never a plain region."""
    plain = False
    return plain
def copy(self):
    """Copy of this region: members, auxiliaries, inp, oup and meta are all copied."""
    members = [(member.copy(), flag) for member, flag in self.regs]
    auxiliaries = [(comp.copy(), flag) for comp, flag in self.auxs]
    return RegionOp(self.rtype, members, auxiliaries,
                    self.inp.copy(), self.oup.copy(), iutil.copy(self.meta))
def copy_noaux(self):
    """Copy without auxiliaries: members are copied with copy_noaux and auxs is emptied."""
    members = [(member.copy_noaux(), flag) for member, flag in self.regs]
    return RegionOp(self.rtype, members, [],
                    self.inp.copy(), self.oup.copy(), iutil.copy(self.meta))
def noaux(self):
    """Alias of copy_noaux()."""
    stripped = self.copy_noaux()
    return stripped
def copy_(self, other):
    """Overwrite this region in place with a copy of other."""
    self.rtype = other.rtype
    self.regs = [(member.copy(), flag) for member, flag in other.regs]
    self.auxs = [(comp.copy(), flag) for comp, flag in other.auxs]
    self.inp, self.oup = other.inp.copy(), other.oup.copy()
    self.meta = iutil.copy(other.meta)
def imp_flip(self):
    """In place: swap union and intersection, then imp_flip every member. Returns self."""
    current = self.get_type()
    if current == RegionType.UNION:
        self.rtype = RegionType.INTER
    elif current == RegionType.INTER:
        self.rtype = RegionType.UNION
    for member, _ in self.regs:
        member.imp_flip()
    return self
def imp_flipped(self):
    """Return an imp_flip()-ed copy, leaving self unchanged."""
    flipped = self.copy()
    flipped.imp_flip()
    return flipped
def universe_type(rtype):
    """A RegionOp of type rtype with a single universe member (flag True) and no auxiliaries."""
    sole_member = (Region.universe(), True)
    return RegionOp(rtype, [sole_member], [])
def union(xs = None, tosimple = False):
    """Union of the regions in xs (each copied, with flag True).
    If tosimple is True and xs holds exactly one region, a copy of it is returned directly.
    """
    members = [] if xs is None else xs
    if tosimple and len(members) == 1:
        return members[0].copy()
    return RegionOp(RegionType.UNION, [(m.copy(), True) for m in members], [])
def inter(xs = None, tosimple = False):
    """Intersection of the regions in xs (each copied, with flag True).
    If tosimple is True and xs holds exactly one region, a copy of it is returned directly.
    """
    members = [] if xs is None else xs
    if tosimple and len(members) == 1:
        return members[0].copy()
    return RegionOp(RegionType.INTER, [(m.copy(), True) for m in members], [])
def __len__(self):
    """Number of member regions."""
    members = self.regs
    return len(members)
def __getitem__(self, key):
    """Index into the members.
    A slice gives a RegionOp of the same type over the selected (uncopied)
    members. An integer index gives the member itself when its flag is True,
    otherwise a one-member RegionOp that keeps the False flag.
    """
    entry = self.regs[key]
    if isinstance(entry, list):
        return RegionOp(self.rtype, entry, [])
    member, flag = entry
    if not flag:
        return RegionOp(self.rtype, [entry], [])
    return member
def union_list(self, distribute = True):
    """Return regions whose union equals this region.
    With distribute, a distributed copy is computed first. A union then yields
    its members (via indexing); any other region yields a single copy.
    """
    if distribute:
        distributed = self.copy()
        distributed.distribute()
        return distributed.union_list(distribute = False)
    if self.get_type() != RegionType.UNION:
        return [self.copy()]
    return list(self.copy())
def add_meta(self, key, value, children = True):
    """Attach metadata.
    With children (default), the entry is added to every member region and
    self is returned. Otherwise it is stored on this object via IBaseObj.add_meta.
    """
    if children:
        for member, _ in self.regs:
            member.add_meta(key, value, children = children)
        return self
    return IBaseObj.add_meta(self, key, value)
def get_meta(self, key):
    """Metadata for key: this object's own entry if present, otherwise the
    first non-None entry found among the members, otherwise None.
    """
    own = IBaseObj.get_meta(self, key)
    if own is not None:
        return own
    candidates = (member.get_meta(key) for member, _ in self.regs)
    return next((found for found in candidates if found is not None), None)
def remove_meta(self, key):
    """Remove metadata key from this object and from every member. Returns self."""
    IBaseObj.remove_meta(self, key)
    for member, _ in self.regs:
        member.remove_meta(key)
    return self
def ispresent(self, x):
    """Return whether any variable in x appears here (in a member or an auxiliary)."""
    in_members = any(member.ispresent(x) for member, _ in self.regs)
    return in_members or any(comp.ispresent(x) for comp, _ in self.auxs)
def __contains__(self, other):
    """Membership test.
    The strings "exists" / "forall" match when an existential / universal
    auxiliary is present. Otherwise (and for any other value) the test is
    delegated to the members, then to the auxiliaries.
    """
    if isinstance(other, str):
        if other == "exists" and any(flag for _, flag in self.auxs):
            return True
        if other == "forall" and any(not flag for _, flag in self.auxs):
            return True
    return (any(other in member for member, _ in self.regs)
            or any(other in comp for comp, _ in self.auxs))
def allcomprealvar(self):
    """Accumulate (with +=) the real-variable components of all members."""
    combined = Comp.empty()
    for member, _ in self.regs:
        combined += member.allcomprealvar()
    return combined
def allcompreal_exprlist(self):
    """ExprArray of the members' allcompreal_exprlist() items, merged with iadd_noduplicate."""
    merged = ExprArray([])
    for member, _ in self.regs:
        merged.iadd_noduplicate(member.allcompreal_exprlist())
    return merged
def allcomprealvar_exprlist(self):
    """ExprArray of the members' allcomprealvar_exprlist() items, merged with iadd_noduplicate."""
    merged = ExprArray([])
    for member, _ in self.regs:
        merged.iadd_noduplicate(member.allcomprealvar_exprlist())
    return merged
def rename_var(self, name0, name1):
    """Rename variable name0 to name1 in all members, then in all auxiliaries."""
    for group in (self.regs, self.auxs):
        for item, _ in group:
            item.rename_var(name0, name1)
def rename_map(self, namemap):
    """Apply the renaming map namemap to all members, then to all auxiliaries."""
    for group in (self.regs, self.auxs):
        for item, _ in group:
            item.rename_map(namemap)
def getaux(self):
    """Existentially quantified auxiliaries.
    These are the auxiliaries flagged True, plus each member's getaux() (flag
    True) or getauxi() (flag False, since the member is complemented).
    """
    collected = Comp.empty()
    for comp, is_exists in self.auxs:
        if is_exists:
            collected += comp
    for member, positive in self.regs:
        collected += member.getaux() if positive else member.getauxi()
    return collected
def getauxi(self):
    """Return all universal auxiliaries (flag False), including nested ones.

    A negated child swaps its existential and universal auxiliaries, so its
    getaux() contributes here instead of its getauxi().
    """
    result = Comp.empty()
    for a, flag in self.auxs:
        if not flag:
            result += a
    for child, flag in self.regs:
        result += child.getauxi() if flag else child.getaux()
    return result
def getauxall(self):
    """Return every auxiliary at this layer and below, regardless of quantifier."""
    result = Comp.empty()
    for a, _ in self.auxs:
        result += a
    for child, _ in self.regs:
        result += child.getauxall()
    return result
def getauxs(self):
    """Return this layer's auxiliaries as a list of (copied Comp, quantifier flag) pairs."""
    copied = []
    for a, flag in self.auxs:
        copied.append((a.copy(), flag))
    return copied
@property
def aux(self):
    """Existential auxiliaries of this region; equivalent to getaux()."""
    return self.getaux()
@property
def auxi(self):
    """Universal auxiliaries of this region; equivalent to getauxi()."""
    return self.getauxi()
@fcn_substitute
def substitute(self, v0, v1):
    """Substitute variable v0 by v1 (v1 can be compound), in place."""
    for child, _ in self.regs:
        child.substitute(v0, v1)
    # Auxiliaries are only touched when v0 is not an Expr.
    if not isinstance(v0, Expr):
        for a, _ in self.auxs:
            a.substitute(v0, v1)
    if iutil.check_meta_subs_criteria(v0, v1, self):
        iutil.substitute(self.meta, v0, v1)
    return self
@fcn_substitute
def substitute_whole(self, v0, v1):
    """Substitute variable v0 by v1 (v1 can be compound), in place."""
    for child, _ in self.regs:
        child.substitute_whole(v0, v1)
    # Auxiliaries are only touched when v0 is not an Expr.
    if not isinstance(v0, Expr):
        for a, _ in self.auxs:
            a.substitute_whole(v0, v1)
    if iutil.check_meta_subs_criteria(v0, v1, self):
        iutil.substitute_whole(self.meta, v0, v1)
    return self
@fcn_substitute
def substitute_aux(self, v0, v1):
    """Substitute variable v0 by v1 (v1 can be compound), and remove auxiliary v0, in place."""
    for child, _ in self.regs:
        child.substitute_aux(v0, v1)
    # Strip v0 from every auxiliary group; drop groups that become empty.
    kept = []
    for a, flag in self.auxs:
        remaining = a - v0
        if not remaining.isempty():
            kept.append((remaining, flag))
    self.auxs = kept
    if iutil.check_meta_subs_criteria(v0, v1, self):
        iutil.substitute(self.meta, v0, v1)
    return self
def remove_present(self, v):
    """Recursively call remove_present(v) on children and drop auxiliaries in which v appears."""
    for child, _ in self.regs:
        child.remove_present(v)
    self.auxs = [pair for pair in self.auxs if not pair[0].ispresent(v)]
def remove_notpresent(self, v):
    """Recursively call remove_notpresent(v) on children and keep only auxiliaries in which v appears."""
    for child, _ in self.regs:
        child.remove_notpresent(v)
    self.auxs = [pair for pair in self.auxs if pair[0].ispresent(v)]
def remove_relax(self, v, sn = 1):
    """Recursively call remove_relax on children and drop auxiliaries containing v.

    The sign sn is flipped (multiplied by -1) for negated children.
    """
    for child, flag in self.regs:
        factor = 1 if flag else -1
        child.remove_relax(v, sn * factor)
    self.auxs = [pair for pair in self.auxs if not pair[0].ispresent(v)]
def condition(self, b):
    """Condition on random variable b, in place (delegated to every child region)."""
    for child, _ in self.regs:
        child.condition(b)
    return self
def symm_sort(self, terms):
    """Sort the random variables in terms assuming symmetry among those terms (in every child)."""
    for child, _ in self.regs:
        child.symm_sort(terms)
def relax(self, w, gap):
    """Relax real variables in w by gap, in place.

    Negated children are relaxed by -gap, since relaxing the complement
    goes in the opposite direction.
    """
    for child, flag in self.regs:
        child.relax(w, gap if flag else -gap)
    return self
def record_to(self, index):
    """Record the variables of all child regions and auxiliaries into index."""
    for item, _ in itertools.chain(self.regs, self.auxs):
        item.record_to(index)
def pack_type_self(self, totype):
    """Give this region outer type totype, in place, and return self.

    Without auxiliaries, a region already of that type is left alone, and a
    single-child union/intersection is simply relabelled. Otherwise the
    current region is wrapped as the single (positive) child of a new layer.
    """
    if not self.auxs:
        ctype = self.get_type()
        if ctype == totype:
            return self
        if ctype in (RegionType.UNION, RegionType.INTER) and len(self.regs) == 1:
            self.rtype = totype
            return self
    wrapped = self.copy()
    self.regs = [(wrapped, True)]
    self.auxs = []
    self.rtype = totype
    return self
def pack_type(x, totype):
    """Return a copy of region x whose outer type is totype, wrapping x if required."""
    if len(x.getauxs()) == 0:
        ctype = x.get_type()
        if ctype == totype:
            return x.copy()
        if ctype in (RegionType.UNION, RegionType.INTER) and len(x.regs) == 1:
            retyped = x.copy()
            retyped.rtype = totype
            return retyped
    return RegionOp(totype, [(x.copy(), True)], [])
def iand_norename(self, other):
    """Intersect with other in place without renaming clashing auxiliaries."""
    self.pack_type_self(RegionType.INTER)
    packed = RegionOp.pack_type(other, RegionType.INTER)
    self.regs += packed.regs
    self.auxs += packed.auxs
    self.inter_compress()
    return self
def inter_compress(self):
    """Merge aux-free normal constraints at this layer into one leading child.

    Applies only to unions and intersections. For an intersection, positive
    children are merged; for a union, negated children are merged (since
    ~a | ~b == ~(a & b)). Merged children are reset to the universe and then
    dropped, together with any other trivially-universal entries of the
    merged sign.
    """
    ctype = self.get_type()
    if ctype != RegionType.INTER and ctype != RegionType.UNION:
        return
    curc = ctype == RegionType.INTER
    merged = Region.universe()
    for child, flag in self.regs:
        if flag == curc and child.isnormalcons() and child.getauxall().isempty():
            merged &= child
            child.setuniverse()
    candidates = [(merged, curc)] + self.regs
    self.regs = [(child, flag) for child, flag in candidates
                 if flag != curc or not child.isuniverse()]
def normalcons_sort(self):
    """Reorder children: normal constraints of the layer's sign first, other children next, opposite-sign normal constraints last.

    Relative order inside each group is preserved. Applies only to unions
    and intersections.
    """
    ctype = self.get_type()
    if ctype != RegionType.INTER and ctype != RegionType.UNION:
        return
    curc = ctype == RegionType.INTER
    front, middle, back = [], [], []
    for child, flag in self.regs:
        if not child.isnormalcons():
            middle.append((child, flag))
        elif flag == curc:
            front.append((child, flag))
        else:
            back.append((child, flag))
    self.regs = front + middle + back
def __iand__(self, other):
    """In-place intersection (`&=`).

    A bool operand gives self (True) or the empty region (False). Otherwise
    other's auxiliaries are renamed to avoid clashes before merging.
    """
    if isinstance(other, bool):
        return self if other else Region.empty()
    other = iutil.ensure_region(other)
    if other.isuniverse(canon = True):
        return self
    self.pack_type_self(RegionType.INTER)
    packed = RegionOp.pack_type(other, RegionType.INTER)
    self.aux_avoid(packed)
    self.regs += packed.regs
    self.auxs += packed.auxs
    self.inter_compress()
    return self
def __and__(self, other):
    """Return the intersection of a copy of self with other."""
    result = self.copy()
    result &= other
    return result
def __rand__(self, other):
    """Reflected `&`: intersect a copy of self with other."""
    result = self.copy()
    result &= other
    return result
def __ior__(self, other):
    """In-place union (`|=`).

    A bool operand gives the universe (True) or self (False). An empty
    operand leaves self unchanged; otherwise other's auxiliaries are renamed
    to avoid clashes before merging.
    """
    if isinstance(other, bool):
        return Region.universe() if other else self
    other = iutil.ensure_region(other)
    if other.isuniverse(sgn = False, canon = True):
        return self
    packed = RegionOp.pack_type(other, RegionType.UNION)
    self.pack_type_self(RegionType.UNION)
    self.aux_avoid(packed)
    self.regs += packed.regs
    self.auxs += packed.auxs
    return self
def ior_norename(self, other):
    """Union with other in place without renaming clashing auxiliaries."""
    packed = RegionOp.pack_type(iutil.ensure_region(other), RegionType.UNION)
    self.pack_type_self(RegionType.UNION)
    self.regs += packed.regs
    self.auxs += packed.auxs
    return self
def rior(self, other):
    """Union with other in place, placing other's children and auxiliaries before self's."""
    packed = RegionOp.pack_type(iutil.ensure_region(other), RegionType.UNION)
    self.pack_type_self(RegionType.UNION)
    self.aux_avoid(packed)
    self.regs = packed.regs + self.regs
    self.auxs = packed.auxs + self.auxs
    return self
def __or__(self, other):
    """Return the union of a copy of self with other."""
    result = self.copy()
    result |= other
    return result
def __ror__(self, other):
    """Reflected `|`: union a copy of self with other."""
    result = self.copy()
    result |= other
    return result
def append_avoid(self, x, c = True):
    """Append a copy of x (flag c) as a child, after renaming its clashing auxiliaries."""
    item = x.copy()
    self.aux_avoid(item)
    self.regs.append((item, c))
def __imul__(self, other):
    """Multiply every child region by other, in place, and return self.

    Children are stored as (region, flag) tuples, which are immutable, so the
    original `self.regs[i][0] *= other` always raised TypeError on the
    tuple item assignment. The product is now rebuilt into a fresh tuple,
    which also keeps the result when the child's `*=` returns a new object.
    """
    for i, (child, flag) in enumerate(self.regs):
        child *= other
        self.regs[i] = (child, flag)
    return self
def negate(self):
    """Complement this region in place (De Morgan) and return self.

    Union and intersection swap, and every child flag and auxiliary
    quantifier flag is inverted.
    """
    ctype = self.get_type()
    if ctype == RegionType.UNION:
        self.rtype = RegionType.INTER
    elif ctype == RegionType.INTER:
        self.rtype = RegionType.UNION
    self.regs = [(child, not flag) for child, flag in self.regs]
    self.auxs = [(a, not flag) for a, flag in self.auxs]
    return self
def __invert__(self):
    """Return the complement of this region (self is left unchanged)."""
    return self.copy().negate()
def negateornot(x, c):
    """Return a copy of x when c is true, otherwise its complement ~x."""
    return x.copy() if c else ~x
def setuniverse(self):
    """Make this region the universe in place: an intersection with no children.

    Auxiliaries are not cleared.
    """
    self.regs = []
    self.rtype = RegionType.INTER
def setempty(self):
    """Make this region empty in place: a union with no children.

    Auxiliaries are not cleared.
    """
    self.regs = []
    self.rtype = RegionType.UNION
def universe():
    """Return a new RegionOp for the whole space (an intersection of nothing)."""
    no_children, no_auxs = [], []
    return RegionOp(RegionType.INTER, no_children, no_auxs)
def empty():
    """Return a new RegionOp for the empty region (a union of nothing)."""
    no_children, no_auxs = [], []
    return RegionOp(RegionType.UNION, no_children, no_auxs)
def z3(self, vardict = None):
    """Return a Z3 formula for this region, or None when the z3 module is unavailable.

    Each child is converted (wrapped in Not when its flag is False). Children
    are joined with Or for a union or And for an intersection; a single child
    is used as-is. Auxiliaries are then wrapped around the result in list
    order, so the first one is innermost: Exists for flag True, ForAll for
    flag False.
    """
    if z3 is None:
        return None
    if vardict is None:
        vardict = iutil.z3_vardict(self.rvs, self.aux, self.auxi, self.reals)
    parts = [x.z3(vardict) if c else z3.Not(x.z3(vardict)) for x, c in self.regs]
    if len(parts) == 1:
        formula = parts[0]
    else:
        ctype = self.get_type()
        formula = parts
        if ctype == RegionType.UNION:
            formula = z3.Or(parts)
        elif ctype == RegionType.INTER:
            formula = z3.And(parts)
    for a, c in self.auxs:
        bound = [y.z3(vardict) for y in a]
        formula = z3.Exists(bound, formula) if c else z3.ForAll(bound, formula)
    return formula
def sum_minkowski(self, other):
    """Return the Minkowski sum of this region and other (a new region).

    Both operands are distributed first. The sum is pushed into each member
    of a union (self's first, then other's), carrying the combined
    auxiliaries of both operands. The intersection cases below are an
    approximation (see the original note), and the fallback returns a copy
    of the distributed self.
    NOTE(review): the children's negation flags c are reused unchanged on
    the summed children; confirm this is intended for negated members.
    """
    cs = self.copy()
    cs.distribute()
    other = other.copy()
    other.distribute()
    # warnings.warn("Minkowski sum of union or intersection regions is unsupported.", RuntimeWarning)
    auxs = [(x.copy(), c) for x, c in cs.getauxs() + other.getauxs()]
    if cs.get_type() == RegionType.UNION:
        return RegionOp(RegionType.UNION, [(x.sum_minkowski(other), c) for x, c in cs.regs], auxs)
    if other.get_type() == RegionType.UNION:
        return RegionOp(RegionType.UNION, [(cs.sum_minkowski(x), c) for x, c in other.regs], auxs)
    # The following are technically wrong
    # (the Minkowski sum does not distribute over intersections).
    if cs.get_type() == RegionType.INTER:
        return RegionOp(RegionType.INTER, [(x.sum_minkowski(other), c) for x, c in cs.regs], auxs)
    if other.get_type() == RegionType.INTER:
        return RegionOp(RegionType.INTER, [(cs.sum_minkowski(x), c) for x, c in other.regs], auxs)
    return cs.copy()
def implicate(self, other, skip_simplify = False):
    """Turn self into the implication `other -> self` (i.e. self | ~other), in place.

    skip_simplify is accepted for interface compatibility and is unused.
    """
    premise = iutil.ensure_region(other)
    self |= ~premise
    return self
def implicate_norename(self, other, skip_simplify = False):
    """Like implicate(), but without renaming clashing auxiliaries.

    skip_simplify is accepted for interface compatibility and is unused.
    """
    premise = iutil.ensure_region(other)
    self.ior_norename(~premise)
    return self
def implicated(self, other, skip_simplify = False):
    """Return `other -> self` as a new region, leaving self unchanged."""
    premise = iutil.ensure_region(other)
    return self.copy().implicate(premise, skip_simplify)
def sum_entrywise(self, other):
    """Combine this region with other entry by entry into a new RegionOp.

    Child regions are paired positionally and combined with sum_entrywise.
    Auxiliaries are paired positionally and combined with interleaved. Flags
    come from self. Unpaired trailing entries of the longer operand are
    dropped (zip semantics).

    Fix: the interleaved auxiliaries were appended to ``r.regs`` (the child
    region list) instead of ``r.auxs``, which mixed auxiliary Comps into the
    regions.
    """
    r = RegionOp(self.rtype, [], [])
    for (x1, c1), (x2, _) in zip(self.regs, other.regs):
        r.regs.append((x1.sum_entrywise(x2), c1))
    for (a1, c1), (a2, _) in zip(self.auxs, other.auxs):
        r.auxs.append((a1.interleaved(a2), c1))
    return r
def corners_optimum(self, w, sn):
    """Delegate corners_optimum(w, sn) to the positive children and recombine.

    A child result that is the universe is treated as "no refinement".
    Union: every child is kept, and positive children are replaced by their
        refinement when one exists. The union of those is returned if at
        least one child was refined; otherwise the universe.
    Intersection: for each positive child that can be refined, a copy of the
        whole intersection with that child replaced is built. The union of
        these copies (with self's auxiliaries) is returned, or the universe
        if none.
    Other region types give the universe.

    Fix: the intersection branch searched a copy for ``tt.regs[i] is x``.
    The copy holds fresh (region, flag) tuples, so that identity test never
    matched and the child was never replaced. The replacement is now done by
    position, and the (region, flag) tuple shape is kept.
    """
    ctype = self.get_type()
    if ctype == RegionType.UNION:
        r = self.copy()
        r.regs = []
        did = False
        for x, c in self.regs:
            if not c:
                r.append_avoid(x.copy(), c)
                continue
            t = x.corners_optimum(w, sn)
            if t.isuniverse():
                r.append_avoid(x.copy(), c)
            else:
                r.append_avoid(t, c)
                did = True
        if did:
            return r
        return Region.universe()
    if ctype == RegionType.INTER:
        r = RegionOp.union([])
        r.auxs = [(x.copy(), c) for x, c in self.auxs]
        for i, (x, c) in enumerate(self.regs):
            if not c:
                continue
            t = x.corners_optimum(w, sn)
            if t.isuniverse():
                continue
            tt = self.copy()
            tt.regs[i] = (t, c)
            r.append_avoid(tt)
        if len(r.regs) > 0:
            return r
        return Region.universe()
    return Region.universe()
def balance(self, v = None, w = None, skip_simplify = False):
    """Balance every child with respect to v and w, then simplify unless skip_simplify.

    w defaults to all random variables of this region. Children are never
    simplified individually; only the final result is.
    """
    scope = self.allcomprv() if w is None else w
    for child, _ in self.regs:
        child.balance(v, scope, skip_simplify = True)
    if not skip_simplify:
        self.simplify()
    return self
def sign_present(self, term):
    """Return [appears_with_sign0, appears_with_sign1] for term across children.

    A negated child's pair is reversed (in place) before being OR-ed in.
    Scanning stops early once both flags are set.
    """
    first = second = False
    for child, flag in self.regs:
        pair = child.sign_present(term)
        if not flag:
            pair.reverse()
        first |= pair[0]
        second |= pair[1]
        if first and second:
            break
    return [first, second]
def substitute_sign(self, v0, v1s):
    """Substitute v0 sign-dependently in every child and report which signs occurred.

    For negated children the replacement list is reversed (unless it is
    falsy) and the child's returned pair is reversed too, since complementing
    swaps the roles of the two signs. Returns the OR of all children's pairs.
    """
    first = second = False
    flipped = list(reversed(v1s)) if v1s else v1s
    for child, flag in self.regs:
        if flag:
            pair = child.substitute_sign(v0, v1s)
        else:
            pair = child.substitute_sign(v0, flipped if v1s else v1s)
            pair.reverse()
        first |= pair[0]
        second |= pair[1]
    return [first, second]
def substitute_duplicate(self, v0, v1s):
    """Forward substitute_duplicate(v0, v1s) to every child region."""
    for child, _ in self.regs:
        child.substitute_duplicate(v0, v1s)
def flatten_minmax(self, term, sn, bds):
    """Eliminate a min/max-type region term using its bounds bds, in place.

    Occurrences of term are replaced (via substitute_sign) by two fresh real
    variables, v1s = [lower, upper]; the pair is swapped when sn < 0.
    - If the second kind occurred, it is constrained by every bound in bds
      (v * sn <= b * sn) and then eliminated.
    - If the first kind occurred, the region is rebuilt as a union in which
      each member containing it is duplicated once per bound b, with the
      variable replaced by b.
    Returns self.
    """
    latex_name = term.tostring(style = "latex")
    v1s = [Expr.real(self.name_avoid(str(term) + "_L") + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + latex_name),
           Expr.real(self.name_avoid(str(term) + "_U") + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + latex_name)]
    sn_present = self.substitute_sign(Expr.fromterm(term), v1s)
    if sn < 0:
        sn_present.reverse()
        v1s.reverse()
    if sn_present[1]:
        treg = Region.universe()
        for b in bds:
            treg &= v1s[1] * sn <= b * sn
        # `self &=` relies on __iand__ mutating and returning self.
        self &= treg
        self.eliminate(v1s[1])
    if sn_present[0]:
        tself = RegionOp.pack_type(self.copy(), RegionType.UNION)
        # A negated top-level member cannot be iterated as a plain union member,
        # so wrap the whole thing once more.
        if any(not c for x, c in tself.regs):
            tself = RegionOp.union([tself])
        # Reset self to the empty union, then rebuild it member by member.
        self.copy_(RegionOp.union([]))
        for t2 in tself:
            if t2.ispresent(v1s[0]):
                for b in bds:
                    t3 = t2.copy()
                    t3.substitute(v1s[0], b)
                    self |= t3
            else:
                self |= t2
    return self
def lowest_present(self, v, sn):
    """Find the deepest region in which v appears, tracking the sign sn.

    If exactly one child contains v, the search recurses into it, with sn
    flipped when that child is negated. If the recursion yields nothing,
    or several children contain v, self is returned when sn is true and
    None otherwise. None is also returned if no child contains v.
    """
    matches = [(child, flag) for child, flag in self.regs if child.ispresent(v)]
    if not matches:
        return None
    if len(matches) == 1:
        child, flag = matches[0]
        found = child.lowest_present(v, sn ^ (not flag))
        if found is not None:
            return found
    return self if sn else None
def term_sn_present(self, term, sn):
    """Return whether term occurs with the sign sn * term.sn (no substitution performed)."""
    effective = sn * term.sn
    flags = self.substitute_sign(Expr.fromterm(term), None)
    return flags[1] if effective < 0 else flags[0]
def term_sn_present_both(self, term):
    """Return whether term occurs with both signs (no substitution performed)."""
    flags = self.substitute_sign(Expr.fromterm(term), None)
    return flags[0] and flags[1]
def flatten_regterm(self, term, isimp = True, minmax_elim = False):
    """Expand a region-defined term into explicit constraints, in place.

    term is replaced sign-dependently by two fresh real variables (lower /
    upper, swapped when term.sn < 0); with term.sn == 0 only the first one
    is used. The defining region term.reg is then attached either as a
    premise (isimp=True, implication form) or as a conjunct (isimp=False).
    With minmax_elim, a term whose get_reg_sgn_bds() reports a trivial
    region and a nonzero sign is handled by flatten_minmax instead.
    Returns self, or None when term has no region or flatten_minmax was used.

    Fix: get_reg_sgn_bds() may return None (the general path below already
    checks `rsb is not None`). The min/max path used to unpack it blindly
    and raise TypeError; it now falls through to the general expansion.
    """
    if term.reg is None:
        return
    if minmax_elim:
        rsb_minmax = term.get_reg_sgn_bds()
        if rsb_minmax is not None:
            treg, tsgn, tbds = rsb_minmax
            if treg.isuniverse() and tsgn != 0:
                self.flatten_minmax(term, tsgn, tbds)
                return
    self.simplify_quick()
    sn = term.sn
    latex_name = term.tostring(style = "latex")
    v1s = [Expr.real(self.name_avoid(str(term) + "_L") + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + latex_name),
           Expr.real(self.name_avoid(str(term) + "_U") + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + latex_name)]
    sn_present = self.substitute_sign(Expr.fromterm(term), v1s)
    if sn < 0:
        sn_present.reverse()
        v1s.reverse()
    if sn == 0:
        sn_present[1] = False
        sn_present[0] = True
    if sn_present[1]:
        term2 = term.copy()
        self.aux_avoid(term2.reg)
        rsb = term2.get_reg_sgn_bds()
        if (PsiOpts.settings["flatten_distribute"] and rsb is not None and not rsb[0].imp_present() and isimp
                and (PsiOpts.settings["flatten_distribute_multi"] or len(rsb[2]) == 1)):
            # Distribute the bounds into the innermost sub-region containing
            # the variable, instead of adding a global disjunct.
            treg, tsgn, tbds = rsb
            lpres = self.lowest_present(v1s[1], True)
            lpres.substitute_duplicate(v1s[1], tbds)
            taux = treg.aux.copy()
            treg.aux = Comp.empty()
            lpres &= treg
            lpres.eliminate(taux)
        else:
            reg2 = term.reg.broken_present(Expr.fromterm(term), flipped = True)
            reg2.substitute(Expr.fromterm(term), v1s[1])
            if isimp:
                self |= reg2
            else:
                self &= ~reg2
    if sn_present[0]:
        reg2 = term.reg.copy()
        reg2.substitute(Expr.fromterm(term), v1s[0])
        if isimp:
            self.implicate(reg2)
        else:
            self &= reg2
    return self
def flatten_ivar(self, ivar, isimp = True):
    """Replace a region-defined random variable ivar by a plain variable, in place.

    The defining region ivar.reg (with auxiliaries renamed to avoid clashes)
    is attached as a premise (isimp=True) or a conjunct (isimp=False), and
    ivar is then substituted by the fresh plain variable newvar. If ivar is
    not marked reg_det, a conditional-independence constraint (built with
    Expr.Ic and simplified) is also added to that region when it is nonzero.
    Returns self, or None when ivar has no region.
    """
    if ivar.reg is None:
        return
    newvar = Comp([ivar.copy_noreg()])
    reg2 = ivar.reg.copy()
    self.aux_avoid(reg2)
    if not ivar.reg_det:
        newindep = Expr.Ic(reg2.getaux() + newvar,
            self.allcomprv().reg_excluded() - self.getaux() - reg2.allcomprv() - newvar,
            reg2.allcomprv_noaux() - newvar).simplified_quick()
        if not newindep.iszero():
            reg2.iand_norename(newindep == 0)
    #self.implicate_norename(newindep == 0)
    if isimp:
        #self.implicate_norename(reg2.exists(newvar), skip_simplify = True)
        self.implicate_norename(reg2, skip_simplify = True)
    else:
        self &= reg2
    self.substitute(Comp([ivar]), newvar)
    return self
def isregtermpresent(self):
    """Return whether any child region contains a region-defined term."""
    return any(child.isregtermpresent() for child, _ in self.regs)
def regtermmap(self, cmap, recur):
    """Collect region-defined terms of all children and auxiliaries into cmap.

    Auxiliaries are scanned by wrapping them in the trivial constraint
    H(aux) >= 0 and delegating to that region's regtermmap.
    """
    for child, _ in self.regs:
        child.regtermmap(cmap, recur)
    all_aux = Comp.empty()
    for a, _ in self.auxs:
        all_aux += a
    (H(all_aux) >= 0).regtermmap(cmap, recur)
def regterm_split(self, term):
    """Split a term into separate upper-/lower-bound terms when needed.

    Applies when term occurs with both signs, or has an outer region
    (reg_outer). Occurrences are substituted sign-dependently by
    term.upper_bound(...) and term.lower_bound(...) under fresh names.
    Returns True if the split was performed, False otherwise.
    """
    if not self.term_sn_present_both(term) and term.reg_outer is None:
        return False
    upper = Expr.fromterm(term.upper_bound(name = self.name_avoid(str(term) + "_L0")))
    lower = Expr.fromterm(term.lower_bound(name = self.name_avoid(str(term) + "_R0")))
    self.substitute_sign(Expr.fromterm(term), [upper, lower])
    return True
def flatten(self, minmax_elim = False):
    """Repeatedly expand region-defined terms and variables until none remain, in place.

    Each round collects the region terms appearing directly in self,
    excluding those that also appear inside another term's region
    (regterms_exc). It first tries to split a term used with both signs
    (regterm_split); if nothing splits, it expands one term:
      - pass 0: IVar terms, and terms not occurring with positive sign;
      - pass 1: the remaining non-IVar terms.
    When the "proof_enabled" and "proof_step_expand_def" settings are on, an
    "Expand definitions" proof step is recorded. Returns self.
    """
    verbose = PsiOpts.settings.get("verbose_flatten", False)
    write_pf_enabled = PsiOpts.settings.get("proof_enabled", False) and PsiOpts.settings.get("proof_step_expand_def", False)
    if write_pf_enabled:
        prevself = self.copy()
    did = True
    didall = False
    while did:
        did = False
        regterms = {}
        regterms_exc = {}
        self.regtermmap(regterms, False)
        # Terms nested in another term's region are expanded later, after
        # their enclosing term has been flattened.
        for (name, term) in regterms.items():
            term.reg.regtermmap(regterms_exc, True)
        for (name, term) in regterms.items():
            if name not in regterms_exc:
                if isinstance(term, IVar):
                    pass
                else:
                    if self.regterm_split(term):
                        did = True
                        break
        if did:
            continue
        for cpass, (name, term) in itertools.product(range(2), regterms.items()):
            if name not in regterms_exc:
                if isinstance(term, IVar):
                    if cpass == 1:
                        continue
                else:
                    if cpass == 0 and self.term_sn_present(term, 1):
                        continue
                if verbose:
                    print("========= flatten op ========")
                    print(self)
                    print("========= term ========")
                    print(term)
                    print("========= region ========")
                    print(term.reg)
                if isinstance(term, IVar):
                    self.flatten_ivar(term)
                else:
                    self.flatten_regterm(term, minmax_elim = minmax_elim)
                did = True
                didall = True
                if verbose:
                    print("========= to ========")
                    print(self)
                # Only one expansion per round; the term map is recomputed.
                break
    if write_pf_enabled:
        if didall:
            pf = ProofObj.from_region(("equiv", prevself, self), c = "Expand definitions:")
            PsiOpts.set_setting(proof_add = pf)
    return self
def tosimple(self):
    """Try to convert this RegionOp into a simple (normal) Region.

    Succeeds for: an empty union, a union with a single positive child that
    converts, or an intersection of positive children that convert and
    contain no implications. Existential auxiliaries are then eliminated.
    Returns None whenever the conversion is not possible (including any
    universal auxiliary).
    """
    r = Region.universe()
    if self.get_type() == RegionType.UNION:
        if len(self.regs) == 0:
            r = Region.empty()
        elif len(self.regs) == 1 and self.regs[0][1]:
            r = self.regs[0][0].tosimple()
            if r is None:
                return None
        else:
            return None
    elif self.get_type() == RegionType.INTER:
        for x, c in self.regs:
            if not c:
                return None
            t = x.tosimple()
            if t is None or t.imp_present():
                return None
            r.iand_norename(t)
    for x, c in self.auxs:
        if not c:
            return None
        r.eliminate(x)
    return r
def tosimple_safe(self):
    """Like tosimple(), but return None if any universal or nested auxiliary is present."""
    if not self.getauxi().isempty():
        return None
    if any(child.aux_present() for child, _ in self.regs):
        return None
    return self.tosimple()
def tonormal_safe(self):
    """Return the tosimple_safe() conversion if it succeeds, otherwise a copy of self."""
    simple = self.tosimple_safe()
    return self.copy() if simple is None else simple
def level(self, mode = ""):
    """The level of a region in the linear entropy hierarchy.

    Child levels (complemented for negated children) are joined with `|`,
    and each auxiliary layer applies `exists(not flag)`.
    """
    result = Level()
    for child, flag in self.regs:
        lv = child.level(mode)
        result = result | (lv if flag else ~lv)
    for _, flag in self.auxs:
        result = result.exists(not flag)
    return result
def complexity(self):
    """Heuristic size: children's complexity plus 100 per auxiliary layer and per auxiliary variable."""
    child_cost = sum(child.complexity() for child, _ in self.regs)
    aux_vars = sum(len(a) for a, _ in self.auxs)
    return child_cost + len(self.auxs) * 100 + aux_vars * 100
def sorting_priority(self):
    """Sort key for regions; currently identical to complexity()."""
    priority = self.complexity()
    return priority
def var_neighbors(self, v):
    """Return v together with every child's var_neighbors(v)."""
    neighbors = v.copy()
    for child, _ in self.regs:
        neighbors += child.var_neighbors(v)
    return neighbors
def one_flipped(self):
    """Not supported for compound regions; always returns None."""
    result = None
    return result
def distribute(self, force_split = False):
    """Expand to a single union layer, in place, and return self.

    Union: each RegionOp child (every child when force_split) is turned
    positive by negation, distributed recursively, and, if it is a union,
    spliced in (its auxiliaries are prepended to self.auxs).
    Intersection: self becomes a union whose members are the cross-product
    intersections of the children's union members (starting from the
    universe); the children's auxiliaries are prepended to self.auxs.
    With force_split, normal children holding implications (or negated
    ones) are first converted with toregionop_split.
    Other region types are returned unchanged.
    """
    if self.get_type() == RegionType.UNION:
        tregs = []
        for x, c in self.regs:
            if force_split and x.get_type() == RegionType.NORMAL:
                if x.imp_present():
                    x = x.toregionop_split()
                elif not c:
                    x = x.toregionop_split(force_split = True)
            if force_split or isinstance(x, RegionOp):
                if not c:
                    # negate() mutates x in place; the flag is now positive.
                    x = x.negate()
                    c = True
                x.distribute(force_split = force_split)
            if c and x.get_type() == RegionType.UNION:
                tregs += x.regs
                self.auxs = x.getauxs() + self.auxs
            else:
                tregs.append((x, c))
        self.regs = tregs
        return self
    if self.get_type() == RegionType.INTER:
        # Start from the universe: the neutral element of the intersection.
        tregs = [(Region.universe(), True)]
        self.rtype = RegionType.UNION
        for x, c in self.regs:
            if force_split and x.get_type() == RegionType.NORMAL:
                if x.imp_present():
                    x = x.toregionop_split()
                elif not c:
                    x = x.toregionop_split(force_split = True)
            if force_split or isinstance(x, RegionOp):
                if not c:
                    x = x.negate()
                    c = True
                x.distribute(force_split = force_split)
            if c and x.get_type() == RegionType.UNION:
                # (a1 | a2 ...) & (y1 | y2 ...) -> union of all ai & yj.
                tregs2 = []
                for y, cy in x.regs:
                    for a, ca in tregs:
                        tregs2.append((RegionOp.negateornot(a, ca) & RegionOp.negateornot(y, cy), True))
                tregs = tregs2
                self.auxs = x.getauxs() + self.auxs
            else:
                tregs = [(RegionOp.negateornot(a, ca) & RegionOp.negateornot(x, c), True) for a, ca in tregs]
                self.auxs = x.getauxs() + self.auxs
        self.regs = tregs
        return self
    return self
def tounion(self):
    """Return a distributed copy of self whose outer layer is guaranteed to be a union."""
    flat = self.copy().distribute()
    if flat.get_type() != RegionType.UNION:
        return RegionOp.pack_type(flat, RegionType.UNION)
    return flat
def aux_appearance(self, curc):
    """List every auxiliary group with its effective quantifier and visible variables.

    Returns a list of (aux Comp, effective flag, other random variables in
    scope) triples. The flag is XOR-ed along the path (curc ^ (not c)), so
    negation flips the quantifier. Nested children come first, then this
    layer's auxiliaries in their original order. Each of this layer's
    entries sees the non-auxiliary variables plus the auxiliaries listed
    after it.
    """
    allcomprv = self.allcomprv() - self.getauxall()
    r = []
    for x, c in self.regs:
        if isinstance(x, RegionOp):
            r += x.aux_appearance(curc ^ (not c))
        else:
            xallcomprv = x.allcomprv() - x.getauxall()
            if not x.auxi.isempty():
                r.append((x.auxi.copy(), curc ^ (not c) ^ True, xallcomprv.copy()))
            if not x.aux.isempty():
                r.append((x.aux.copy(), curc ^ (not c), xallcomprv.copy()))
    # Walk auxiliaries innermost-first so each one records the variables of the
    # layers nested inside it, then restore the original order.
    rt = []
    for x, c in self.auxs[::-1]:
        rt.append((x.copy(), curc ^ (not c), allcomprv.copy()))
        allcomprv += x
    r += rt[::-1]
    return r
def aux_remove(self):
    """Drop all auxiliaries at this layer and recursively in every child, in place."""
    self.auxs = []
    for child, _ in self.regs:
        if not isinstance(child, RegionOp):
            child.aux = Comp.empty()
            child.auxi = Comp.empty()
        else:
            child.aux_remove()
    return self
def aux_collect(self):
    """Collect auxiliaries to outermost layer.

    Children's auxiliaries (flag flipped for negated children) are moved,
    in child order, in front of this layer's own auxiliaries.
    """
    gathered = []
    for child, flag in self.regs:
        if isinstance(child, RegionOp):
            child.aux_collect()
            gathered += [(a, ca ^ (not flag)) for a, ca in child.auxs]
            child.auxs = []
        else:
            gathered += [(a, ca ^ (not flag)) for a, ca in child.getauxs()]
            child.aux = Comp.empty()
            child.auxi = Comp.empty()
    self.auxs = gathered + self.auxs
def break_imp(self):
    """Rewrite normal children that contain implications as explicit RegionOps, recursively.

    Each such child becomes `consonly() <- imp_flippedonly()` (built with
    implicated). Its universal auxiliaries are attached as the new
    RegionOp's aux layer with flag False.
    """
    for i in range(len(self.regs)):
        if self.regs[i][0].get_type() == RegionType.NORMAL and self.regs[i][0].imp_present():
            rc = self.regs[i][0].consonly()
            ri = self.regs[i][0].imp_flippedonly()
            ri.aux = Comp.empty()
            r = RegionOp.union([rc])
            if not ri.isuniverse():
                r = r.implicated(ri)
            if not self.regs[i][0].auxi.isempty():
                # Replaces (does not extend) any auxiliaries r acquired above.
                r.auxs = [(self.regs[i][0].auxi.copy(), False)]
            self.regs[i] = (r, self.regs[i][1])
        if isinstance(self.regs[i][0], RegionOp):
            self.regs[i][0].break_imp()
def break_imp_old(self):
    """Older variant of break_imp(): attaches all of the child's auxiliaries (getauxs()).

    NOTE(review): the recursion below calls break_imp(), not break_imp_old(),
    so only the top layer uses this variant.
    """
    for i in range(len(self.regs)):
        if self.regs[i][0].get_type() == RegionType.NORMAL and self.regs[i][0].imp_present():
            auxs = self.regs[i][0].getauxs()
            rc = self.regs[i][0].consonly()
            rc.aux = Comp.empty()
            ri = self.regs[i][0].imp_flippedonly()
            ri.aux = Comp.empty()
            r = RegionOp.union([rc])
            if not ri.isuniverse():
                r = r.implicated(ri)
            r.auxs = auxs
            self.regs[i] = (r, self.regs[i][1])
        if isinstance(self.regs[i][0], RegionOp):
            self.regs[i][0].break_imp()
def from_region(x):
    """Convert region x into a RegionOp with implications broken out.

    A RegionOp is simply copied. A normal region is wrapped in a union and
    passed through break_imp(); if that yields a single positive RegionOp
    child, the child itself is returned.
    """
    if isinstance(x, RegionOp):
        return x.copy()
    wrapped = RegionOp.union([x])
    wrapped.break_imp()
    if len(wrapped.regs) == 1:
        inner, flag = wrapped.regs[0]
        if flag and isinstance(inner, RegionOp):
            return inner
    return wrapped
def break_present(self, w, flipped = True):
    """Apply broken_present(w, flipped) to every normal descendant, in place."""
    for i, (child, flag) in enumerate(self.regs):
        if isinstance(child, RegionOp):
            child.break_present(w, flipped)
            continue
        self.regs[i] = (child.broken_present(w, flipped), flag)
    return self
def broken_present(self, w, flipped = True):
    """Return a copy of self with break_present(w, flipped) applied."""
    return self.copy().break_present(w, flipped)
def aux_clean(self):
    """Merge consecutive auxiliary groups sharing the same quantifier flag, in place."""
    merged = []
    for a, flag in self.auxs:
        if merged and merged[-1][1] == flag:
            merged[-1] = (merged[-1][0] + a, flag)
        else:
            merged.append((a, flag))
    self.auxs = merged
def emptycons_present(self):
    """Return whether some child has no consequence (negated) component.

    An intersection child counts as having one if any of its own children is
    negated; a normal child counts if it is itself negated. Any other child
    type counts as having none.
    """
    for child, flag in self.regs:
        ctype = child.get_type()
        if ctype == RegionType.INTER:
            has_cons = any(not cy for _, cy in child.regs)
        elif ctype == RegionType.NORMAL:
            has_cons = not flag
        else:
            has_cons = False
        if not has_cons:
            return True
    return False
def sign_has_component(self, sn):
    """Return whether some normal descendant has effective flag sn (XOR-ed through negations)."""
    for child, flag in self.regs:
        if child.get_type() != RegionType.NORMAL:
            if child.sign_has_component(sn ^ (not flag)):
                return True
        elif flag == sn:
            return True
    return False
def sign_only_inplace(self, sn, outermost = False):
    """Keep only the components with effective flag sn, in place.

    Normal children are kept when their flag equals sn. Compound children
    are pruned recursively. At the outermost layer a compound child is kept
    only if its flag equals sn and it actually contains such a component.
    Auxiliaries with a different flag are dropped.
    """
    kept = []
    for child, flag in self.regs:
        if child.get_type() == RegionType.NORMAL:
            if flag == sn:
                kept.append((child, flag))
            continue
        child_sn = sn ^ (not flag)
        if outermost:
            if flag == sn and child.sign_has_component(child_sn):
                child.sign_only_inplace(child_sn)
                kept.append((child, flag))
        else:
            child.sign_only_inplace(child_sn)
            kept.append((child, flag))
    self.regs = kept
    self.auxs = [(a, flag) for a, flag in self.auxs if flag == sn]
def sign_only(self, sn, outermost = False):
    """Return a copy of self pruned with sign_only_inplace(sn, outermost)."""
    pruned = self.copy()
    pruned.sign_only_inplace(sn, outermost)
    return pruned
def consonly(self):
    """Return only the positive (constraint) components, pruned from the outermost layer."""
    keep_sign = True
    return self.sign_only(keep_sign, outermost = True)
def imponly(self):
    """Return only the negated (implication) components, pruned from the outermost layer."""
    keep_sign = False
    return self.sign_only(keep_sign, outermost = True)
def process_pm(self, fcn1, fcn0):
    """Apply fcn1 to positive and fcn0 to negative normal descendants.

    Negation swaps the two callbacks on the way down. The callbacks' return
    values are discarded, so they must modify the region in place.
    """
    for child, flag in self.regs:
        on_pos, on_neg = (fcn1, fcn0) if flag else (fcn0, fcn1)
        if child.get_type() == RegionType.NORMAL:
            on_pos(child)
        else:
            child.process_pm(on_pos, on_neg)
def presolve_strengthen(self):
    """Run presolve_process on every normal descendant.

    relax=True is used for positive descendants and relax=False for negated
    ones.
    """
    def with_relax(reg):
        reg.presolve_process(relax = True)
        return reg
    def without_relax(reg):
        reg.presolve_process(relax = False)
        return reg
    self.process_pm(with_relax, without_relax)
def presolve(self):
    """Normalize this region in place for the auxiliary search, and return self.

    Steps: break implications, quick-simplify, expand definitions (flatten),
    break implications again and strengthen. The auxiliary quantifier
    structure is recorded (aux_appearance), auxiliaries are stripped, and
    the region is distributed into one union layer. Normal constraints are
    sorted and members are compressed, then the recorded auxiliaries are
    reinstalled as the outer layer. An empty-region member is appended when
    every member has a consequence part (emptycons_present() is False).
    """
    self.break_imp()
    self.simplify_quick(zero_group = 1)
    self.flatten(minmax_elim = PsiOpts.settings["flatten_minmax_elim"])
    self.break_imp()
    self.presolve_strengthen()
    aux_ap = self.aux_appearance(True)
    self.aux_remove()
    self.distribute()
    # Disabled: would add the remaining (non-auxiliary) variables as an
    # outermost universal layer.
    if False:
        t = Comp.empty()
        for x, c, ccomp in aux_ap:
            t += x
        t = self.allcomprv() - t
        if not t.isempty():
            #aux_ap = [(t, False, Comp.empty())] + aux_ap
            aux_ap.append((t, False, Comp.empty()))
    #self.inter_compress()
    self.normalcons_sort()
    for x, c in self.regs:
        if isinstance(x, RegionOp):
            x.inter_compress()
    self.auxs = [(x, c) for x, c, ccomp in aux_ap]
    self.aux_clean()
    if not self.emptycons_present():
        self.regs.append((Region.empty(), True))
    return self
def to_cause_consequence(self):
    """Split the negation of this region into (requirement, consequence, auxiliaries) cases.

    The complement ~self is presolved into a union. For each member, its
    positive parts are intersected into the requirement, and its negated
    parts are collected as consequences: a single consequence is used
    directly, several are combined with RegionOp.union. The existential
    auxiliaries of ~self are restricted to the member's variables and
    attached. Positive normal members that are empty are skipped.
    Returns a list of (req, ccons, cauxi) triples.
    """
    cs = ~self
    cs.presolve()
    r = []
    # Existential auxiliaries of the complement (named auxi here; presumably
    # the universal ones of self, since negation flips quantifiers).
    auxi = cs.getaux()
    for x, c in cs.regs:
        req = Region.universe()
        cons = []
        if x.get_type() == RegionType.INTER:
            for y, cy in x.regs:
                if cy:
                    req &= y
                else:
                    cons.append(y)
        elif x.get_type() == RegionType.NORMAL:
            if c:
                if x.isempty():
                    continue
                req &= x
            else:
                cons.append(x)
        ccons = None
        if len(cons) == 1:
            ccons = cons[0]
        else:
            ccons = RegionOp.union(cons)
        cauxi = auxi.inter(x.allcomp())
        r.append((req, ccons, cauxi))
    return r
def get_var_avoid(self, a):
    """Intersect the children's get_var_avoid(a) results.

    None results are skipped once a value exists; if the first result is
    None, the first later non-None result is taken. Returns None when no
    child gives a value.
    """
    common = None
    for child, _ in self.regs:
        found = child.get_var_avoid(a)
        if common is None:
            common = found
        elif found is not None:
            common = common.inter(found)
    return common
def auxs_icreg(auxs, othercomp, exclcomp):
    """Build conditional-independence constraints for universal auxiliary groups.

    Groups are scanned innermost (last) first. Each group with flag False
    (minus exclcomp) gets Region.Ic(group, othercomp, variables of the
    groups already scanned). Every group, whatever its flag, then joins the
    conditioning set.
    """
    result = Region.universe()
    inner = Comp.empty()
    for group, flag in auxs[::-1]:
        if not flag:
            result.iand_norename(Region.Ic(group - exclcomp, othercomp, inner))
        inner += group
    return result
def auxs_incomp(auxs, x):
    """Restrict each (aux, flag) pair to the variables in x, dropping pairs that become empty."""
    restricted = [(group.inter(x), flag) for group, flag in auxs]
    kept = []
    for group, flag in restricted:
        if not group.isempty():
            kept.append((group, flag))
    return kept
def calc_req_cons(self):
    """Presolve self in place, then return one (requirement, consequences) pair per member.

    Positive parts of a member are intersected into the requirement, and
    negated parts are collected into the consequence list.
    """
    self.presolve()
    pairs = []
    for member, flag in self.regs:
        req = Region.universe()
        cons = []
        mtype = member.get_type()
        if mtype == RegionType.INTER:
            for part, pflag in member.regs:
                if pflag:
                    req &= part
                else:
                    cons.append(part)
        elif mtype == RegionType.NORMAL:
            if flag:
                req &= member
            else:
                cons.append(member)
        pairs.append((req, cons))
    return pairs
def get_req_cons(self):
    """Return (calc_req_cons() of a copy, that copy's auxiliaries after presolving); self is unchanged."""
    work = self.copy()
    pairs = work.calc_req_cons()
    return (pairs, work.auxs)
def check_getaux_inplace(self, must_include = None, single_include = None, hint_pair = None, hint_aux = None, hint_aux_avoid = None, max_iter = None, leaveone = None, strengthen = None, get_info = None):
verbose = PsiOpts.settings.get("verbose_auxsearch", False)
verbose_step = PsiOpts.settings.get("verbose_auxsearch_step", False)
verbose_op = PsiOpts.settings.get("verbose_auxsearch_op", False)
verbose_op_step = PsiOpts.settings.get("verbose_auxsearch_op_step", False)
verbose_op_detail = PsiOpts.settings.get("verbose_auxsearch_op_detail", False)
verbose_op_detail2 = PsiOpts.settings.get("verbose_auxsearch_op_detail2", False)
ignore_must = PsiOpts.settings["ignore_must"]
forall_multiuse = PsiOpts.settings["forall_multiuse"]
forall_multiuse_numsave = PsiOpts.settings["forall_multiuse_numsave"]
auxsearch_local = PsiOpts.settings["auxsearch_local"]
init_leaveone = PsiOpts.settings["init_leaveone"]
auxsearch_aux_strengthen = PsiOpts.settings["auxsearch_aux_strengthen"]
if leaveone is None:
leaveone = PsiOpts.settings["auxsearch_leaveone"]
save_res = auxsearch_local
if max_iter is None:
max_iter = PsiOpts.settings["auxsearch_max_iter"]
if hint_aux is None:
hint_aux = []
if hint_aux_avoid is None:
hint_aux_avoid = []
if must_include is None:
must_include = Comp.empty()
if strengthen is None:
strengthen = PsiOpts.settings["auxsearch_strengthen"]
casesteplimit = PsiOpts.settings["auxsearch_op_casesteplimit"]
caselimit = PsiOpts.settings["auxsearch_op_caselimit"]
write_pf_enabled = PsiOpts.settings.get("proof_enabled", False)
write_pf_repeat_claim = PsiOpts.settings.get("proof_repeat_implicant", False)
yield_one = write_pf_enabled and PsiOpts.settings.get("proof_yield_one", False)
if verbose_op:
print("========= aux search op ========")
print(self)
self.presolve()
if verbose_op:
print("========= expanded ========")
print(self)
print()
csallcomprv = self.allcomprv()
csaux = Comp.empty()
csauxi = Comp.empty()
for a, ca in self.auxs:
if not ca:
csauxi += a
else:
csaux += a
write_pf_twopass = write_pf_enabled and not csaux.isempty()
if write_pf_twopass:
PsiOpts.settings["proof_enabled"] = False
write_pf_enabled_inner = PsiOpts.settings.get("proof_enabled", False)
max_numupdate_one = write_pf_enabled_inner and PsiOpts.settings.get("auxsearch_max_numupdate_one_if_proof", False)
allauxs = self.auxs
csnonaux = csallcomprv - csaux - csauxi
if not csnonaux.isempty():
allauxs.append((csnonaux, False))
csauxiall = csauxi + csnonaux
csallcomprv = csaux + csauxiall
csaux_id = IVarIndex()
csaux_id.record(csaux)
csauxiall_id = IVarIndex()
csauxiall_id.record(csauxiall)
csall_id = IVarIndex()
csall_id.record(csallcomprv)
csaux_is = [csall_id.get_index(a) for a in csaux.varlist]
csauxiall_is = [csall_id.get_index(a) for a in csauxiall.varlist]
csauxidep = Comp.empty()
nvar = len(csallcomprv)
n = len(self.regs)
nreqcheck = 0
nfinal = 0
depgraph = [[False] * nvar for j in range(nvar)]
xallcomprv = []
xconscomprv = []
xreq = []
xcons = []
xmultiuse = []
xleaveone = []
xaux = []
xauxi = []
xaux_avoid = []
xoneuse_aux = []
xcforall = []
xauxs_incomp = []
xreqcomprv = []
xconsonly = []
xconsonly_init = []
init_reg = Region.universe()
for i in range(n):
x, c = self.regs[i]
allcomprv = x.allcomprv()
cur_auxs_incomp = RegionOp.auxs_incomp(allauxs, allcomprv)
req = Region.universe()
cons = []
if x.get_type() == RegionType.INTER:
for y, cy in x.regs:
if cy:
req &= y
else:
cons.append(y)
elif x.get_type() == RegionType.NORMAL:
if c:
req &= x
else:
cons.append(x)
conscomprv = Comp.empty()
for con in cons:
conscomprv += con.allcomprv()
cur_multiuse = forall_multiuse
aux = Comp.empty()
#auxi = csauxi.inter(conscomprv)
auxi = csauxi.inter(allcomprv)
aux_avoid = []
cavoid = Comp.empty()
cavoidmask = 0
cforall_c = Comp.empty()
cforall = Comp.empty()
for a, ca in allauxs:
b = a.inter(allcomprv)
bmask = csall_id.get_mask(b)
if not b.isempty():
if ca:
aux += b
aux_avoid.append((b.copy(), cavoid.copy()))
for j in range(nvar):
if bmask & (1 << j) != 0:
for j2 in range(nvar):
if cavoidmask & (1 << j2) != 0:
depgraph[j2][j] = True
if not cforall_c.isempty():
cur_multiuse = False
csauxidep += cforall_c
else:
cforall_c += b.inter(auxi)
cavoid += b
cavoidmask |= bmask
#cavoid += a
if len(cons) == 0:
cur_multiuse = True
cforall_c = Comp.empty()
for a, ca in allauxs[::-1]:
if ca:
if not a.inter(allcomprv).isempty():
cforall = cforall_c.copy()
else:
cforall_c += a
cur_leaveone = leaveone and cur_multiuse and len(cons) == 0
innermost_auxi = Comp.empty()
if auxsearch_aux_strengthen and len(cur_auxs_incomp) and not cur_auxs_incomp[0][1]:
innermost_auxi = cur_auxs_incomp[0][0].copy() - csnonaux
cons2 = cons
cons = []
for x in cons2:
y = x.copy()
if not innermost_auxi.isempty():
innermost_auxi_int = innermost_auxi.inter(y.allcomprv())
if not innermost_auxi_int.isempty():
y.aux += innermost_auxi_int
y.aux_strengthen(req.allcomprv() + csnonaux)
y.aux -= innermost_auxi_int
for v in aux:
y.substitute(v, Comp.rv(v.get_name() + "_R" + str(i)))
cons.append(y)
oneuse_aux = MHashSet()
oneuse_aux.add(None)
if req.isuniverse() and aux.isempty():
pass
else:
nreqcheck += 1
if len(cons) == 0:
nfinal += 1
if init_leaveone and aux.isempty() and csnonaux.super_of(auxi):
ofl = req.one_flipped()
if ofl is not None:
init_reg &= ofl.add_meta("pf_note", ["trivial otherwise"])
cur_consonly = req.isuniverse() and aux.isempty()
cur_consonly_init = cur_consonly and len(cons) == 1 and csnonaux.super_of(auxi)
# if cur_consonly_init:
# init_reg &= cons[0]
xallcomprv.append(allcomprv)
xconscomprv.append(conscomprv)
xreq.append(req)
xcons.append(cons)
xmultiuse.append(cur_multiuse)
xleaveone.append(cur_leaveone)
xaux.append(aux)
xauxi.append(auxi)
xaux_avoid.append(aux_avoid)
xoneuse_aux.append(oneuse_aux)
xcforall.append(cforall)
xauxs_incomp.append(cur_auxs_incomp)
xreqcomprv.append(req.allcomprv())
xconsonly.append(cur_consonly)
xconsonly_init.append(cur_consonly_init)
if verbose_op:
print("========= #" + iutil.strpad(str(i), 3, " requires ========"))
print(req)
if len(cons) > 0:
print("========= consequences ========")
print("\nOR\n".join(str(con) for con in cons))
if not aux.isempty() or not auxi.isempty():
print("========= auxiliary ========")
print(" ".join(("|" if c else "&") + str(a) for a, c in cur_auxs_incomp))
print("Multiuse = " + str(cur_multiuse), ", Leave one = " + str(cur_leaveone))
print("Cons only = " + str(cur_consonly))
print()
hint_aux_avoid = hint_aux_avoid + self.get_aux_avoid_list()
mustcomp = Comp.empty()
if not ignore_must:
for a in csauxi:
cmarkers = a.get_markers()
cdict = {v: w for v, w in cmarkers}
if cdict.get("mustuse", False):
mustcomp += a
rcases = MHashSet()
rcases.add((init_reg, [False] * n))
oneuse_added = [False] * n
res = MHashList()
rcases_hashset = set()
#rcases_hashset.add(hash(rcases))
oneuse_set = MHashSet()
oneuse_set.add(([None] * len(csaux), []))
max_iter_pow = 4
#cur_max_iter = 200
cur_max_iter = 800
max_yield_pow = 3
#cur_max_yield = 5
cur_max_yield = 16
max_yield_add = 0
max_numupdate = 1000000000000
if max_numupdate_one:
max_numupdate = 1
if nreqcheck <= 1:
cur_max_iter *= 1000000000
cur_max_yield *= 1000000000
if yield_one:
cur_max_yield = 1
max_yield_pow = 1
max_yield_add = 1
did = True
prev_did = True
caselimit_warned = False
caselimit_reached = False
if verbose_op:
print("========= markers ========")
for a in csallcomprv:
cmarkers = a.get_markers()
if len(cmarkers) > 0:
print(iutil.strpad(str(a), 6, " : " + str(cmarkers)))
print("Must use: " + str(mustcomp))
print("csnonaux: " + str(csnonaux))
print("csauxidep: " + str(csauxidep))
print("Max iter: " + str(cur_max_iter) + " Max yield: " + str(cur_max_yield))
print("========= init region ========")
print(init_reg)
print("========= begin search ========")
cnstep = 0
consonly_reg = None
#while did and (max_iter <= 0 or cur_max_iter < max_iter):
while (did or cnstep <= 1) and (max_iter <= 0 or cur_max_iter < max_iter):
if PsiOpts.is_timer_ended():
break
cnstep += 1
prev_did = did
did = False
#rcases3 = rcases
nonsimple_did = False
numupdate = 0
if cnstep >= 10:
max_numupdate = 1000000000000
for i in range(n):
if len(rcases) == 0:
break
cur_consonly = xconsonly[i]
cur_consonly_init = xconsonly_init[i]
if cnstep > 1 and cur_consonly:
continue
if cnstep == 1 and not cur_consonly:
continue
x, c = self.regs[i]
allcomprv = xallcomprv[i]
conscomprv = xconscomprv[i]
req = xreq[i]
cons = xcons[i]
cur_multiuse = xmultiuse[i]
cur_leaveone = xleaveone[i]
aux = xaux[i]
auxi = xauxi[i]
aux_avoid = xaux_avoid[i]
oneuse_aux = xoneuse_aux[i]
cforall = xcforall[i]
cur_auxs_incomp = xauxs_incomp[i]
reqcomprv = xreqcomprv[i]
if not cur_consonly:
if not aux.isempty() or len(cons) > 0 or cur_leaveone:
nonsimple_did = True
if len(rcases) > caselimit:
caselimit_reached = True
if not caselimit_warned:
caselimit_warned = True
warnings.warn("Max number of cases " + str(caselimit) + " reached. May give false reject.", RuntimeWarning)
if len(cons) >= 2 and caselimit_reached:
continue
rcases2 = rcases
rcases = MHashSet()
for rcase_tuple in rcases2:
rcase = rcase_tuple[0]
rcase_vis = rcase_tuple[1]
rcur = None
if aux.isempty():
rcur = req.implicated(rcase)
else:
rcase_t = rcase
if strengthen:
rcase_t = rcase_t.copy()
rcase_t.aux += csauxi.inter(rcase_t.allcomprv())
rcase_t.simplify_aux_eq()
rcase_t.aux = Comp.empty()
#rcur = req.implicated(rcase).exists(aux).forall(csnonaux)
# rcur = req.implicated(rcase_t).exists(aux).forall(csauxiall)
rcur = req.exists(aux).implicated(rcase_t).forall(csauxiall)
rcases_toadd = MHashSet()
rcases_toadd.add((rcase.copy(), rcase_vis[:]))
cur_yield = 0
for oneaux in reversed(oneuse_set):
auxvis = oneaux[1]
if i in auxvis:
continue
auxmasks = oneaux[0]
auxlist = [(None if a is None else csauxiall.from_mask(a)) for a in auxmasks]
mustleft = Comp.empty()
if nfinal == 1 and len(cons) == 0 and not mustcomp.isempty():
cmustcomp = Comp.empty()
for i2 in auxvis:
cmustcomp += xauxi[i2].inter(mustcomp)
if not cmustcomp.isempty():
auxlistall = sum((a for a in auxlist if a is not None), Comp.empty())
mustleft = cmustcomp - auxlistall
if aux.isempty() and not mustleft.isempty():
#print("MUST NOT")
#print(str(mustcomp) + " " + str(cmustcomp) + " " + str(auxlistall))
#print(rcur)
continue
rcur2 = rcur.copy()
if verbose_op_detail:
print("========= #" + iutil.strpad(str(i), 3, " step ========"))
print("SUB " + " ".join(str(csaux[j]) + ":" + str(auxlist[j]) for j in range(len(csaux)) if auxlist[j] is not None))
print("DEP " + " ".join(str(i2) for i2 in auxvis))
#print("=====================")
#print(rcur2)
if verbose_op_detail2:
print("========= #" + iutil.strpad(str(i), 3, " before indep ===="))
print(rcur2)
#clcomp = csnonaux.copy()
clcomp = csauxiall - csauxidep
#clcomp = rcur2.imp_flippedonly().allcomprv() + csnonaux
#for i2 in auxvis:
# clcomp -= xallcomprv[i2]
#for i2 in reversed(tsorted):
# if i2 != i and oneauxs[i2] is None:
# clcomp += xallcomprv[i2]
for i2 in auxvis:
#tauxi = xauxi[i2] - clcomp
#tcond = xallcomprv[i2] - xauxi[i2]
#rcur2.iand_norename(Region.Ic(tauxi, clcomp, tcond).imp_flipped())
#print("TREG")
#print(Expr.Ic(tauxi, clcomp, tcond))
rcur2.iand_norename(RegionOp.auxs_icreg(xauxs_incomp[i2], clcomp - xallcomprv[i2], clcomp + xreqcomprv[i2]).imp_flipped())
if verbose_op_detail2:
treg2 = RegionOp.auxs_icreg(xauxs_incomp[i2], clcomp - xallcomprv[i2], clcomp)
if not treg2.isuniverse():
print("========= #" + iutil.strpad(str(i), 3, " indep " + str(i2) + " ====="))
print(treg2)
print("clcomp=" + str(clcomp))
clcomp += xallcomprv[i2] - csaux
for i2 in range(n):
if i2 not in auxvis:
if rcase_vis[i2]:
rcur2.remove_present(xaux[i2].added_suffix("_R" + str(i2)))
else:
for v in xaux[i2]:
w = auxlist[csaux.varlist.index(v.varlist[0])]
if w is not None:
rcur2.substitute_aux(Comp.rv(v.get_name() + "_R" + str(i2)), w)
for j in range(len(csaux)):
if auxlist[j] is not None:
rcur2.substitute_aux(csaux[j], auxlist[j])
#print("SUB " + "; ".join([str(v) + ":" + str(w) for v, w in oneaux]))
if verbose_op_detail2:
print("========= #" + iutil.strpad(str(i), 3, " after rename ===="))
print(rcur2)
#print(rcur2.getaux())
#print(rcur2.getaux().get_markers())
#if i == 4:
# return None
hint_aux_add = [(csaux[i], auxlist[i]) for i in range(len(csaux)) if auxlist[i] is not None]
hint_aux_avoid_add = [(csaux[i], csallcomprv - auxlist[i]) for i in range(len(csaux)) if auxlist[i] is not None]
#print(rcur2)
cdepgraph = [a[:] for a in depgraph]
for j in range(len(csaux)):
mask = auxmasks[j]
if mask is not None:
for j2 in range(len(csauxiall_is)):
if mask & (1 << j2) != 0:
cdepgraph[j][csauxiall_is[j2]] = True
for j in range(len(csaux)):
if auxlist[j] is None and aux.ispresent(csaux[j]):
cmask = 0
for j2 in range(len(csauxiall_is)):
tdepgraph = [a[:] for a in cdepgraph]
tdepgraph[j][csauxiall_is[j2]] = True
if iutil.iscyclic(tdepgraph):
cmask |= 1 << j2
if cmask != 0:
hint_aux_avoid_add.append((csaux[j], csauxiall.from_mask(cmask)))
if verbose_op_detail2:
print("AVOID " + str(csaux[j]) + " : " + str(csauxiall.from_mask(cmask)))
#rcaseallcomprv = rcase.allcomprv() + csnonaux
rcaseallcomprv = rcur2.imp_flippedonly().allcomprv() + csnonaux
#rcaseallcomprv = csnonaux
creg_indep = None
t_cur_leaveone = cur_leaveone
if numupdate >= max_numupdate:
did = True
t_cur_leaveone = False
oproof = None
if write_pf_enabled_inner:
oproof = PsiOpts.get_proof().copy()
curtermdid = False
for rr in rcur2.check_getaux_inplace_gen(must_include = must_include + mustleft,
single_include = single_include, hint_pair = hint_pair,
hint_aux = hint_aux + hint_aux_add,
hint_aux_avoid = hint_aux_avoid + hint_aux_avoid_add,
max_iter = cur_max_iter, leaveone = t_cur_leaveone):
cur_yield += 1
if cur_yield > cur_max_yield:
did = True
break
stype = iutil.signal_type(rr)
if stype == "":
if len(cons) > 0 and iutil.list_iscomplex(rr):
continue
t_multiuse = cur_multiuse
if t_multiuse and len(cons) > 0 and not csauxidep.isempty() and any(csauxidep.ispresent(w) for v, w in rr):
t_multiuse = False
if t_multiuse:
if creg_indep is None:
creg_indep = RegionOp.auxs_icreg(cur_auxs_incomp, rcaseallcomprv - allcomprv, rcaseallcomprv + reqcomprv)
#creg_indep.simplify()
#creg_indep = RegionOp.auxs_icreg(cur_auxs_incomp, clcomp - allcomprv, clcomp)
#cauxi = auxi - rcaseallcomprv
#ccond = allcomprv - auxi
#ccompleft = rcaseallcomprv - cauxi - ccond
#creg_indep = Region.Ic(cauxi, ccompleft, ccond)
#for v in aux:
# creg_indep.substitute(v, Comp.rv(str(v) + "_R" + str(i)))
#print("CREG")
#print(rcur2)
#print(creg_indep)
rcases_toadd2 = rcases_toadd
rcases_toadd = MHashSet()
for rcase_toadd_tuple in rcases_toadd2:
rcase_toadd = rcase_toadd_tuple[0]
rcase_toadd_vis = rcase_toadd_tuple[1]
for con in cons:
ccon = con.copy()
ccon.iand_norename(creg_indep)
Comp.substitute_list(ccon, rr, suffix = "_R" + str(i))
crcase = rcase_toadd.copy()
#crcase.iand_norename(ccon)
#crcase.simplify_quick(zero_group = 1)
crcase.iand_simplify_quick(ccon)
rcases_toadd.add((crcase, rcase_toadd_vis[:]))
if rcases_toadd != rcases_toadd2:
curtermdid = True
tauxlist = [(None if w is None else w.copy()) for w in auxlist]
for v, w in rr:
#print(">>>>" + str(csaux) + " " + str(v) + " " + str(w))
tauxlist[csaux.varlist.index(v.varlist[0])] = w.copy()
rr2 = [(csaux[j], tauxlist[j]) for j in range(len(csaux)) if tauxlist[j] is not None]
if len(rr2) > 0:
res.add(rr2)
if verbose_op_step:
print("ADD " + " ".join([str(v) + ":" + str(w) for v, w in rr2]) + " y=" + str(cur_yield) + "/" + str(cur_max_yield))
else:
oneuse_added[i] = True
#oneuse_aux.clear()
rcases_toadd2 = rcases_toadd
rcases_toadd = MHashSet()
for rcase_toadd_tuple in rcases_toadd2:
rcase_toadd = rcase_toadd_tuple[0]
rcase_toadd_vis = rcase_toadd_tuple[1]
if rcase_toadd_vis[i]:
rcases_toadd.add(rcase_toadd_tuple)
else:
for con in cons:
crcase = rcase_toadd.copy()
#crcase.iand_norename(con)
#crcase.simplify_quick(zero_group = 1)
crcase.iand_simplify_quick(con)
rcases_toadd.add((crcase, [rcase_toadd_vis[i2] or i2 == i for i2 in range(n)]))
tauxmasks = auxmasks[:]
for v, w in rr:
tauxmasks[csaux.varlist.index(v.varlist[0])] = csauxiall_id.get_mask(w)
#print("; ".join(str(v) + ":" + str(w) for v, w in rr))
#print(cdepends)
if oneuse_set.add((tauxmasks, auxvis + [i])):
did = True
curtermdid = True
if verbose_op_step:
tauxlist = [(None if w is None else w.copy()) for w in auxlist]
for v, w in rr:
tauxlist[csaux.varlist.index(v.varlist[0])] = w.copy()
rr2 = [(csaux[j], tauxlist[j]) for j in range(len(csaux)) if tauxlist[j] is not None]
# if len(cons) == 0:
# if len(rr2) > 0:
# res.add(rr2)
print("ONE " + " ".join([str(v) + ":" + str(w) for v, w in rr2]))
if len(cons) == 0:
break
elif stype == "leaveone":
if len(cons) == 0:
rcases_toadd2 = rcases_toadd
rcases_toadd = MHashSet()
casedid = False
for rcase_toadd_tuple in rcases_toadd2:
rcase_toadd = rcase_toadd_tuple[0]
rcase_toadd_vis = rcase_toadd_tuple[1]
#crcase = rcase_toadd & (rr[2] >= 0)
#crcase.simplify_quick(zero_group = 1)
crcase = rcase_toadd.copy()
if not crcase.implies_ineq_quick(rr[2], ">="):
crcase.iand_simplify_quick((rr[2] >= 0).add_meta("pf_note", ["case"]))
casedid = True
rcases_toadd.add((crcase, rcase_toadd_vis))
if casedid and rcases_toadd != rcases_toadd2:
curtermdid = True
tauxlist = [(None if w is None else w.copy()) for w in auxlist]
for v, w in rr[1]:
tauxlist[csaux.varlist.index(v.varlist[0])] = w.copy()
rr2 = [(csaux[j], tauxlist[j]) for j in range(len(csaux)) if tauxlist[j] is not None]
if len(rr2) > 0:
res.add(rr2)
if verbose_op_step:
print("LVO " + " ".join([str(v) + ":" + str(w) for v, w in rr2]))
numupdate += 1
if numupdate >= max_numupdate:
did = True
break
elif stype == "max_iter_reached":
did = True
if len(rcases_toadd) > casesteplimit:
break
if len(rcases_toadd) == 0:
break
if not curtermdid and oproof is not None:
PsiOpts.set_proof(oproof)
if len(rcases_toadd) > casesteplimit:
break
if len(rcases_toadd) == 0:
break
if cur_yield > cur_max_yield:
did = True
break
rcases += rcases_toadd
if verbose_op_detail and i < n - 1:
print("========= cases ========")
print("\nOR\n".join(str(rcase[0]) for rcase in rcases))
if cnstep == 1 and get_info is not None:
consonly_reg = RegionOp.union([a for a, _ in rcases], tosimple = True)
if cnstep != 1 and not nonsimple_did:
break
# if not nonsimple_did:
# break
rcases_hash = hash(rcases)
if rcases_hash not in rcases_hashset:
did = True
rcases_hashset.add(rcases_hash)
if verbose_op_step:
print("========= cases ========")
print("\nOR\n".join(str(rcase[0]) for rcase in rcases))
if len(rcases) == 0:
break
cur_max_iter = int(cur_max_iter * max_iter_pow)
cur_max_yield = int(cur_max_yield * max_yield_pow + max_yield_add)
if get_info is not None:
for i in range(len(get_info)):
if get_info[i] == "assumption_init":
get_info[i] = consonly_reg.copy()
get_info[i] = get_info[i].exists(csauxi.inter(get_info[i].allcomprv()))
elif get_info[i] == "assumption_final":
get_info[i] = RegionOp.union([a for a, _ in rcases], tosimple = True)
get_info[i] = get_info[i].exists(csauxi.inter(get_info[i].allcomprv()))
elif get_info[i] == "assumption_new":
if consonly_reg.get_type() == RegionType.NORMAL:
bnet = consonly_reg.get_bayesnet()
tlist = []
for ccase, _ in rcases:
tr = ccase.copy()
tr.simplify_redundant(reg = consonly_reg, quick = True, bnet = bnet)
tr.simplify(reg = consonly_reg)
tlist.append(tr)
get_info[i] = RegionOp.union(list(tlist), tosimple = True)
get_info[i] = get_info[i].exists(csauxi.inter(get_info[i].allcomprv()))
else:
get_info[i] = Region.universe()
PsiOpts.settings["proof_enabled"] = write_pf_enabled
if len(rcases) == 0:
if verbose_op:
print("========= success ========")
print(iutil.list_tostr_std(res.x))
resrr = None
#resrr = res.x
if len(res.x) == 1:
resrr = res.x[0]
else:
resrr = res.x
if write_pf_enabled:
if write_pf_twopass:
if verbose_op:
print("Proof consonly:")
print(self.consonly())
print()
print(self.consonly().simplified_quick())
print()
pf = ProofObj.from_region(self if write_pf_repeat_claim else self.consonly().simplified_quick(), c = "Claim:")
PsiOpts.set_setting(proof_step_in = pf)
resrr_dict = Comp.substitute_list_to_dict(resrr, multi = True)
if not Comp.substitute_dict_ismulti(resrr_dict):
pf = ProofObj.from_region(None, c = ["Substitute ", CompArray(resrr_dict).add_meta("omit_bracket", True).add_meta("subs", True), ":"])
PsiOpts.set_setting(proof_add = pf)
# cs = self.copy()
cs = RegionOp.union([])
for i in range(n):
req = xreq[i]
cons = xcons[i]
tadd = (req & ~RegionOp.union(cons)).noaux()
for v in xaux[i]:
tadd.substitute_aux(Comp.rv(v.get_name() + "_R" + str(i)), v)
tadd = tadd.substituted_dict_union_plain(resrr_dict)
cs |= tadd
cs = cs.copy()
# Comp.substitute_list(cs, resrr, isaux = True)
cs = cs.noaux()
if verbose_op:
print("Proof reconstruction:")
print(cs)
print()
if cs.getaux().isempty():
with PsiOpts(proof_enabled = True):
cs.check()
PsiOpts.set_setting(proof_step_out = True)
return resrr
return None
def check_getaux_op_inplace(self, hint_pair = None, hint_aux = None):
    """Return whether implication is true, with auxiliary search result.

    Checks this compound region term by term. Each entry of ``self.regs`` is
    a ``(region, c)`` pair, where ``c`` is False for a negated term. NORMAL
    regions are checked with ``check_getaux``; nested compound regions
    recurse into this method.

    Returns:
        UNION : the first non-None term result, else None.
        INTER : None as soon as one term fails, otherwise the list of all
                term results.
        other : None.

    A negated term cannot be established by checking its un-negated region,
    so it is treated as not proven (result None). This errs on the side of
    rejecting rather than wrongly accepting.
    """
    ctype = self.get_type()
    r = []
    for x, c in self.regs:
        if not c:
            # NOTE(review): proving NOT x would need a dedicated check; treat as unproven.
            t = None
        elif x.get_type() == RegionType.NORMAL:
            t = x.check_getaux(hint_pair, hint_aux)
        else:
            t = x.check_getaux_op_inplace(hint_pair, hint_aux)
        if ctype == RegionType.UNION and t is not None:
            return t
        if ctype == RegionType.INTER and t is None:
            return None
        r.append(t)
    if ctype == RegionType.INTER:
        return r
    return None
def check_getaux(self, hint_pair = None, hint_aux = None):
    """Return whether implication is true, with auxiliary search result.

    Any global "truth" setting, and then any independence region returned by
    get_indreg_checked(), is turned into a premise (``premise >> self``). The
    check is then redone with that setting disabled. Otherwise a copy of self
    is checked in place, so self is left untouched.
    """
    premise = PsiOpts.settings["truth"]
    if premise is not None:
        with PsiOpts(truth = None):
            return (premise >> self).check_getaux(hint_pair, hint_aux)
    indep_premise = self.get_indreg_checked()
    if indep_premise is not None:
        with PsiOpts(indreg_enabled = False):
            return (indep_premise >> self).check_getaux(hint_pair, hint_aux)
    working = self.copy()
    return working.check_getaux_inplace(hint_pair = hint_pair, hint_aux = hint_aux)
def check_getaux_gen(self, hint_pair = None, hint_aux = None, hint_aux_avoid = None):
    """Return whether implication is true, with auxiliary search result.

    Generator form of check_getaux: yields its result once when it is not
    None, and yields nothing otherwise. ``hint_aux_avoid`` is accepted for
    interface compatibility but is not used here.
    """
    found = self.check_getaux(hint_pair = hint_pair, hint_aux = hint_aux)
    if found is None:
        return
    yield found
def check(self):
    """Return whether implication is true.

    Uses check_z3() when the configured solver is "z3". Otherwise the
    implication holds when check_getaux() finds a (non-None) result.
    """
    solver = iutil.get_solver()
    if solver != "z3":
        return self.check_getaux() is not None
    return self.check_z3()
def assumption(self, mode = None):
    """Retrieve the strengthened assumptions for proving this region. The implication in this
    region must be true if this assumption does not hold.

    mode : "new" -> "assumption_new", "init" -> "assumption_init",
           anything else -> "assumption_final" (the kind requested from
           check_getaux_inplace through its get_info list).
    """
    requested = ("assumption_new" if mode == "new"
                 else "assumption_init" if mode == "init"
                 else "assumption_final")
    info = [requested]
    working = self.copy()
    with PsiOpts(cases = True):
        working.check_getaux_inplace(get_info = info)
    # check_getaux_inplace overwrites info[0] with the requested region.
    return info[0]
def evalcheck(self, f):
    """Evaluate whether this compound region holds under the evaluation f.

    Any "truth" setting / independence region is first added as a premise,
    as in check_getaux. A UNION is true when some term (negated when its
    sign c is False) evaluates true. An INTER is true when no term
    evaluates false.

    The previous version read PsiOpts.settings["eps_check"] into an unused
    local; that lookup has been removed.
    """
    truth = PsiOpts.settings["truth"]
    if truth is not None:
        with PsiOpts(truth = None):
            return (truth >> self).evalcheck(f)
    indreg = self.get_indreg_checked()
    if indreg is not None:
        with PsiOpts(indreg_enabled = False):
            return (indreg >> self).evalcheck(f)
    isunion = (self.get_type() == RegionType.UNION)
    for x, c in self.regs:
        # For a UNION this fires on a true (signed) term; for an INTER on a false one.
        if isunion ^ c ^ x.evalcheck(f):
            return isunion
    return not isunion
def evalcheck_vec(self, f):
    """Soft (vectorized / probabilistic-style) version of evalcheck.

    Each term's evalcheck_vec value t is combined as if it were a truth
    probability:
        INTER : product of t (or 1 - t for negated terms);
        UNION : 1 - product of (1 - t) (or t for negated terms).
    Any "truth" setting / independence region is first added as a premise.

    The previous version read PsiOpts.settings["eps_check"] into an unused
    local; that lookup has been removed.
    """
    truth = PsiOpts.settings["truth"]
    if truth is not None:
        with PsiOpts(truth = None):
            return (truth >> self).evalcheck_vec(f)
    indreg = self.get_indreg_checked()
    if indreg is not None:
        with PsiOpts(indreg_enabled = False):
            return (indreg >> self).evalcheck_vec(f)
    isunion = (self.get_type() == RegionType.UNION)
    r = 1
    for x, c in self.regs:
        t = x.evalcheck_vec(f)
        if not isunion ^ c:
            t = 1 - t
        r = r * t
    if isunion:
        r = 1 - r
    return r
def eval_max_violate(self, f):
    """Return the maximum constraint violation of this region under f.

    INTER : the largest violation among positive terms; numpy.inf as soon as
            a negated term is satisfied (violation <= eps_check).
    UNION : the smallest violation among positive terms; 0.0 as soon as a
            negated term is violated (violation > eps_check).
    Other types give 0.0. Any "truth" setting / independence region is
    first added as a premise.
    """
    premise = PsiOpts.settings["truth"]
    if premise is not None:
        with PsiOpts(truth = None):
            return (premise >> self).eval_max_violate(f)
    indep_premise = self.get_indreg_checked()
    if indep_premise is not None:
        with PsiOpts(indreg_enabled = False):
            return (indep_premise >> self).eval_max_violate(f)
    eps = PsiOpts.settings["eps_check"]
    kind = self.get_type()
    if kind == RegionType.INTER:
        worst = 0.0
        for term, positive in self.regs:
            amount = term.eval_max_violate(f)
            if positive:
                worst = max(worst, amount)
            elif amount <= eps:
                return numpy.inf
        return worst
    if kind == RegionType.UNION:
        best = numpy.inf
        for term, positive in self.regs:
            amount = term.eval_max_violate(f)
            if positive:
                best = min(best, amount)
            elif amount > eps:
                return 0.0
        return best
    return 0.0
def eval_sum_violate(self, f, pow = 1, leak = 0.1):
    """Return the summed constraint violation of this region under f.

    INTER : the sum of violations of positive terms; numpy.inf as soon as a
            negated term is satisfied (violation <= eps_check).
    UNION : the smallest violation among positive terms; 0.0 as soon as a
            negated term is violated (violation > eps_check).
    Other types give 0.0. pow and leak are forwarded to every term.
    Unlike eval_max_violate, the "truth" premise is deliberately not applied
    here (that code is disabled in this method).
    """
    eps = PsiOpts.settings["eps_check"]
    kind = self.get_type()
    if kind == RegionType.INTER:
        total = 0.0
        for term, positive in self.regs:
            amount = term.eval_sum_violate(f, pow = pow, leak = leak)
            if positive:
                total = total + amount
            elif amount <= eps:
                return numpy.inf
        return total
    if kind == RegionType.UNION:
        best = numpy.inf
        for term, positive in self.regs:
            amount = term.eval_sum_violate(f, pow = pow, leak = leak)
            if positive:
                best = min(best, amount)
            elif amount > eps:
                return 0.0
        return best
    return 0.0
def implies_getaux(self, other, hint_pair = None, hint_aux = None):
    """Whether self implies other, with auxiliary search result.

    Builds the region ``self <= other`` and returns the result of its
    check_getaux (None when the check fails).
    """
    implication = self <= other
    return implication.check_getaux(hint_pair, hint_aux)
def istight(self, canon = False):
    """Return True when every term reports istight(canon); stops at the first that does not."""
    for term, _sign in self.regs:
        if not term.istight(canon):
            return False
    return True
def tighten(self):
    """Call tighten() on every term region, in order (negated or not)."""
    terms = [term for term, _sign in self.regs]
    for term in terms:
        term.tighten()
def add_meta_present(self, b, key, value):
    """Forward add_meta_present(b, key, value) to every term region."""
    for entry in self.regs:
        entry[0].add_meta_present(b, key, value)
def simplify_op(self):
    """Cheap structural simplification of this compound region, in place.

    Steps:
    1. If the region is the universe (resp. empty), normalize it with
       setuniverse() (resp. setempty()) and stop.
    2. For an INTER/UNION node, a negated NORMAL term is replaced by a
       positive one when x.try_negate(eps_only = True) returns a negation.
    3. If this node has no quantified auxiliaries, drop terms that are
       neutral for the operation. Also splice in the terms of positive
       children of the same type that have no auxiliaries.
    4. Call aux_clean().

    Returns self.
    """
    # Whole region trivially true / false: normalize and stop.
    if self.isuniverse():
        self.setuniverse()
        return self
    if self.isempty():
        self.setempty()
        return self
    if self.get_type() == RegionType.INTER or self.get_type() == RegionType.UNION:
        # Replace "NOT x" by an explicit negation of x when one is available.
        for i in range(len(self.regs)):
            x, c = self.regs[i]
            if not c and x.get_type() == RegionType.NORMAL:
                t = x.try_negate(eps_only = True)
                if t is not None:
                    self.regs[i] = (t, not c)
        # Dropping / flattening terms is only done when there are no quantifiers here.
        if len(self.auxs) == 0:
            tregs = []
            for x, c in self.regs:
                # isuniverse(True) tests for the universe and isuniverse(False) for the
                # empty region (as used in tostring). The argument picks whichever makes
                # the signed term neutral: a universe in INTER, an empty region in UNION,
                # flipped for negated terms.
                if x.isuniverse(c ^ (self.get_type() == RegionType.UNION)):
                    continue
                # A positive child of the same type without quantifiers is flattened in.
                if c and x.get_type() == self.get_type() and len(x.auxs) == 0:
                    tregs += x.regs
                    # NOTE(review): x.auxs is empty by the guard above, so this only
                    # rebinds self.auxs to an equal list.
                    self.auxs = x.auxs + self.auxs
                else:
                    tregs.append((x, c))
            self.regs = tregs
    self.aux_clean()
    return self
def simplify_quick(self, reg = None, zero_group = 0):
    """Simplify a region in place, without linear programming.
    Optional argument reg with constraints assumed to be true.
    zero_group = 2: group all nonnegative terms as a single inequality.

    Does nothing unless the "simplify_enabled" setting is on. Otherwise
    each term is quick-simplified, then simplify_op() is applied.
    Returns self.
    """
    if PsiOpts.settings.get("simplify_enabled", False):
        for term, _sign in self.regs:
            term.simplify_quick(reg, zero_group)
        self.simplify_op()
    return self
def aux_push(self):
    """Push the quantifiers of self.auxs into every term, then clear self.auxs.

    The quantifiers are applied in list order. For a positive term ``(x, True)``,
    an existential auxiliary becomes ``x.exists(a)`` and a universal one
    becomes ``x.forall(a)``.

    A negated term ``(x, False)`` stands for NOT x. Since
    exists a (NOT x) == NOT (forall a x), the quantifier applied to x must be
    flipped. The previous version applied the same quantifier to negated
    terms, which produced the wrong region for them.

    NOTE(review): distributing a quantifier over the terms is exact only for
    exists over a UNION and forall over an INTER. In the other cases the
    result is a relaxation/strengthening; callers should take that into
    account.

    Returns self.
    """
    for i, (x, c) in enumerate(self.regs):
        for a, isexists in self.auxs:
            # Equal flags: keep the quantifier; negated term: use its dual.
            if isexists == c:
                x = x.exists(a)
            else:
                x = x.forall(a)
        self.regs[i] = (x, c)
    self.auxs = []
    return self
def simplify_union(self, reg = None):
    """Simplify a union region in place. May take much longer than Region.simplify().
    Optional argument reg with constraints assumed to be true.

    Steps:
    1. Call distribute().
    2. Stop if the result is not a UNION, or if any quantifier is universal.
    3. Otherwise eliminate the existential variables from every positive
       term (eliminate, simplify, remove_aux).
    4. Drop a positive term i when (term_i & reg) implies another surviving
       positive term j; term i is then redundant in the union.

    Returns self on every path. (The early exits used to return None while
    the normal path returned self.)
    """
    if reg is None:
        reg = Region.universe()
    self.distribute()
    if self.get_type() != RegionType.UNION:
        return self
    if any(not c for x, c in self.auxs):
        return self
    aux = Comp.empty()
    for x, c in self.auxs:
        aux += x
    regc = [i for i, (x, c) in enumerate(self.regs) if c]
    for i in regc:
        term = self.regs[i][0]
        term.eliminate(aux)
        term.simplify()
        term.remove_aux(aux)
    regs_rem = [False] * len(self.regs)
    for i, j in itertools.permutations(regc, 2):
        if regs_rem[i] or regs_rem[j]:
            continue
        xi = self.regs[i][0]
        xj = self.regs[j][0]
        if (xi.exists(aux.inter(xi.allcomprv())) & reg).implies(xj.exists(aux.inter(xj.allcomprv()))):
            regs_rem[i] = True
    self.regs = [(x, c) for i, (x, c) in enumerate(self.regs) if not regs_rem[i]]
    return self
def imp_present(self):
    """Return True unconditionally for a compound region."""
    present = True
    return present
def var_mi_only(self, v):
    """Return True when every term reports var_mi_only(v); stops at the first that does not."""
    return not any(not term.var_mi_only(v) for term, _sign in self.regs)
def sort(self):
    """No-op for compound regions: the terms are left in their current order."""
def simplify(self, reg = None, zero_group = 0, **kwargs):
    """Simplify a region in place.
    Optional argument reg with constraints assumed to be true.
    zero_group = 2: group all nonnegative terms as a single inequality.

    Keyword arguments k=v are applied as temporary settings "simplify_k" = v
    for a recursive call. Nothing is done when "simplify_enabled" is off.

    For an INTER/UNION node, each term is simplified in place. It is
    simplified under reg plus what the other terms allow one to assume:
      - a term with c ^ isunion (a positive INTER term or a negated UNION
        term) contributes its simple form;
      - any other term contributes tosimple_noaux().one_flipped(), when that
        is not None (presumably the negation of a single-inequality region —
        TODO confirm).
    Terms are processed in the order given by reg_sort_priority. Since
    simplification is in place, later terms see earlier terms already
    simplified.

    If "simplify_redundant_op" is on and "simplify_quick" is off, a
    redundancy pass runs. It looks at the terms with c == isunion (positive
    UNION terms, negated INTER terms) that have a simple form. Term i is
    dropped when, under the assumptions from the other remaining terms, it
    implies another such term j. This pass stops early when the PsiOpts
    timer ends.

    Finally simplify_op() runs, and simplify_union(reg) too when
    "simplify_union" is on. Returns self (or the recursive result when kwargs
    are given).
    """
    if kwargs:
        r = None
        with PsiOpts(**{"simplify_" + key: val for key, val in kwargs.items()}):
            r = self.simplify(reg, zero_group)
        return r
    if not PsiOpts.settings.get("simplify_enabled", False):
        return self
    simplify_redundant_op = (PsiOpts.settings.get("simplify_redundant_op", False)
        and not PsiOpts.settings.get("simplify_quick", False))
    self.remove_missing_aux()
    r_assumed = None
    if reg is None:
        r_assumed = Region.universe()
    else:
        r_assumed = reg.copy()
    if PsiOpts.settings.get("simplify_regterm", False):
        self.simplify_regterm(reg)
    with PsiOpts(simplify_regterm = False):
        if self.get_type() == RegionType.INTER or self.get_type() == RegionType.UNION:
            isunion = (self.get_type() == RegionType.UNION)
            regs_sorted = list(self.regs)
            # Terms with c != (not isunion) first, then by increasing complexity.
            def reg_sort_priority(t):
                x, c = t
                return (c == (not isunion), x.complexity())
            regs_sorted.sort(key = reg_sort_priority)
            for i0, (x0, c0) in enumerate(regs_sorted):
                # Assumptions available while simplifying term i0.
                t_assumed = r_assumed.copy()
                for i, (x, c) in enumerate(regs_sorted):
                    if i0 == i:
                        continue
                    if c ^ isunion:
                        t = x.tosimple_noaux()
                        if t is not None:
                            t_assumed &= t
                    else:
                        flipped = x.tosimple_noaux()
                        if flipped is not None:
                            flipped = flipped.one_flipped()
                        if flipped is not None:
                            t_assumed &= flipped
                x0.simplify(t_assumed, zero_group)
            if simplify_redundant_op:
                # Only "union-like" terms (c == isunion) with a simple form take part.
                regs_s = [(x.tosimple_noaux() if not c ^ isunion else None) for x, c in self.regs]
                regs_rem = [False for x, c in self.regs]
                for i in range(len(self.regs)):
                    if PsiOpts.is_timer_ended():
                        break
                    if regs_rem[i] or regs_s[i] is None:
                        continue
                    # Same assumption construction as above, skipping removed terms.
                    t_assumed = r_assumed.copy()
                    for k, (x, c) in enumerate(self.regs):
                        if k == i or regs_rem[k]:
                            continue
                        if c ^ isunion:
                            t = x.tosimple_noaux()
                            if t is not None:
                                t_assumed &= t
                        else:
                            flipped = x.tosimple_noaux()
                            if flipped is not None:
                                flipped = flipped.one_flipped()
                            if flipped is not None:
                                t_assumed &= flipped
                    for j in range(len(self.regs)):
                        if PsiOpts.is_timer_ended():
                            break
                        if i == j or regs_rem[j] or regs_s[j] is None:
                            continue
                        # Term i implies term j: i is redundant.
                        if (regs_s[i] & t_assumed).implies(regs_s[j]):
                            regs_rem[i] = True
                            break
                self.regs = [(x, c) for i, (x, c) in enumerate(self.regs) if not regs_rem[i]]
        self.simplify_op()
        if PsiOpts.settings.get("simplify_union", False):
            self.simplify_union(reg)
    return self
def simplified_quick(self, reg = None, zero_group = 0):
    """Returns the simplified region
    Optional argument reg with constraints assumed to be true
    zero_group = 2: group all nonnegative terms as a single inequality

    Works on a copy using simplify_quick (no linear programming). The
    tosimple_safe() form is returned when it exists, otherwise the
    simplified copy.
    """
    assumed = Region.universe() if reg is None else reg
    result = self.copy()
    result.simplify_quick(assumed, zero_group)
    as_simple = result.tosimple_safe()
    return result if as_simple is None else as_simple
def simplified(self, reg = None, zero_group = 0, **kwargs):
    """Returns the simplified region
    Optional argument reg with constraints assumed to be true
    zero_group = 2: group all nonnegative terms as a single inequality

    Works on a copy; extra keyword arguments are forwarded to simplify().
    The tosimple_safe() form is returned when it exists, otherwise the
    simplified copy.
    """
    assumed = Region.universe() if reg is None else reg
    result = self.copy()
    result.simplify(assumed, zero_group, **kwargs)
    as_simple = result.tosimple_safe()
    return result if as_simple is None else as_simple
def add_aux(self, aux, c):
    """Append a quantified auxiliary ``aux`` with flag ``c`` to self.auxs.

    If the last entry has the same flag, aux is merged into it with ``+``
    (the stored flag is kept). Otherwise a new entry holding ``aux.copy()``
    is appended.
    """
    auxs = self.auxs
    if auxs and auxs[-1][1] == c:
        last_comp, last_flag = auxs[-1]
        auxs[-1] = (last_comp + aux, last_flag)
        return
    auxs.append((aux.copy(), c))
def remove_aux(self, w):
    """Remove the variables w from every quantifier entry; entries that become empty are dropped."""
    previous = self.auxs
    self.auxs = []
    for comp, flag in previous:
        remaining = comp - w
        if remaining.isempty():
            continue
        self.auxs.append((remaining, flag))
def remove_missing_aux(self):
    """Restrict each quantifier entry to variables appearing in the region; drop empty entries.

    self.auxs is cleared before allcomprv() is called, matching the original
    order of operations.
    """
    previous = self.auxs
    self.auxs = []
    present = self.allcomprv()
    trimmed = ((comp.inter(present), flag) for comp, flag in previous)
    self.auxs.extend((comp, flag) for comp, flag in trimmed if not comp.isempty())
def eliminate(self, w, reg = None, toreal = False, forall = False, quick = False, method = "", reg_record = None):
    """Eliminate the variables in w from this region (in place) and return the result.

    Random variables (IVarType.RV) are not eliminated directly unless toreal
    is set. They are recorded as quantified auxiliaries through add_aux
    (existential by default, universal when forall is True).

    Real variables (IVarType.REAL), or every variable when toreal is set,
    are eliminated term by term after distribute(force_split = True). For
    forall, the region is negated, eliminated existentially and negated back
    (forall v A == NOT exists v NOT A).

    If a negated term still mentions a variable to eliminate, the method
    gives up on an exact answer. It sets the region to the universe
    (existential case) or empty (universal case) and returns self at once.

    When something was eliminated, the region is simplified (simplify_quick
    if quick, else simplify) and self.tonormal_safe() is returned. Otherwise
    self is returned. method and reg_record are forwarded to the per-term
    eliminate calls.
    """
    w = Region.get_allcomp(w)
    toelim = Comp.empty()
    for v in w.allcomp():
        if toreal or v.get_type() == IVarType.REAL:
            toelim += v
        elif v.get_type() == IVarType.RV:
            self.add_aux(v, not forall)
    simplify_needed = False
    if not toelim.isempty():
        simplify_needed = True
        # Universal elimination through negation: forall v A == NOT exists v NOT A.
        if forall:
            self.negate()
        self.distribute(force_split = True)
        for x, c in self.regs:
            x.simplify_quick()
            if not c and x.ispresent(toelim):
                # Cannot eliminate inside a negated term; give up on an exact answer.
                if forall:
                    self.setempty()
                else:
                    self.setuniverse()
                return self
        for x, c in self.regs:
            if x.ispresent(toelim):
                # c is True here: negated terms containing toelim returned above.
                x.eliminate(toelim, reg = reg, toreal = toreal, forall = not c, quick = quick, method = method, reg_record = reg_record)
        if forall:
            self.negate()
    if simplify_needed:
        if quick:
            self.simplify_quick(reg)
        else:
            self.simplify(reg)
        return self.tonormal_safe()
    return self
def eliminate_quick(self, w, reg = None, toreal = False, forall = False, method = "", reg_record = None):
    """Same as eliminate(..., quick = True): eliminate w using quick simplification only.

    reg_record is now accepted and forwarded to eliminate, so this shortcut
    exposes the same options as eliminate. It defaults to None, which was
    the previously implied value, so existing callers are unaffected.
    """
    return self.eliminate(w, reg = reg, toreal = toreal, forall = forall, quick = True,
        method = method, reg_record = reg_record)
def marginal_eliminate(self, w):
    """Apply marginal_eliminate(w) to every term region.

    self.regs holds (region, sign) pairs. The previous version iterated the
    pairs themselves and called the method on a tuple, which raised
    AttributeError.
    NOTE(review): the sign is not consulted, so negated terms are processed
    the same way as positive ones, as before.
    """
    for x, c in self.regs:
        x.marginal_eliminate(w)
def kernel_eliminate(self, w):
    """Apply kernel_eliminate(w) to every term region.

    self.regs holds (region, sign) pairs. The previous version iterated the
    pairs themselves and called the method on a tuple, which raised
    AttributeError.
    NOTE(review): the sign is not consulted, so negated terms are processed
    the same way as positive ones, as before.
    """
    for x, c in self.regs:
        x.kernel_eliminate(w)
def tostring(self, style = 0, tosort = False, lhsvar = "real", inden = 0, add_bracket = False,
        small = False, skip_outer_exists = False):
    """Convert to string.
    Parameters:
        style : Style of string conversion
            STR_STYLE_STANDARD : I(X,Y;Z|W)
            STR_STYLE_PSITIP : I(X+Y&Z|W)
        tosort : sort the displayed terms. Terms mentioning a variable of
            lhsvar come first, then shorter strings.
        lhsvar : variables used for sorting; "real" means
            self.allcomprealvar().
        inden : indentation level (count of spacestr units).
        add_bracket : in LaTeX style, whether a single-term region gets
            surrounding braces.
        small : forwarded to the terms' tostring.
        skip_outer_exists : accepted for interface compatibility; not used in
            this method.
    A UNION of exactly one negated and one positive term, (NOT A) OR B, is
    printed as the implication A => B.
    """
    style = iutil.convert_str_style(style)
    cregs = self.regs
    ctype = self.get_type()
    if ctype == RegionType.UNION:
        cregs0 = [(x, True) for x, c in cregs if not c]
        cregs1 = [(x, True) for x, c in cregs if c]
        if len(cregs0) == 1 and len(cregs1) == 1:
            cregs = cregs0 + cregs1
            # Internal marker for "print as implication".
            ctype = -1000
    if isinstance(lhsvar, str) and lhsvar == "real":
        lhsvar = self.allcomprealvar()
    curadd_bracket = True
    if style & PsiOpts.STR_STYLE_LATEX:
        if len(cregs) == 1:
            curadd_bracket = add_bracket
    r = ""
    interstr = ""
    nlstr = "\n"
    notstr = "NOT"
    spacestr = " "
    if style & PsiOpts.STR_STYLE_PSITIP:
        notstr = "~"
    elif style & PsiOpts.STR_STYLE_LATEX:
        notstr = "\\lnot"
        if style & PsiOpts.STR_STYLE_LATEX_ARRAY:
            nlstr = "\\\\\n"
            spacestr = "\\;"
    # Separator placed between terms.
    if ctype == RegionType.UNION:
        if style & PsiOpts.STR_STYLE_PSITIP:
            interstr = "|"
        elif style & PsiOpts.STR_STYLE_LATEX:
            interstr = PsiOpts.settings["latex_or"]
        else:
            interstr = "OR"
    if ctype == RegionType.INTER:
        if style & PsiOpts.STR_STYLE_PSITIP:
            interstr = "&"
        elif style & PsiOpts.STR_STYLE_LATEX:
            interstr = PsiOpts.settings["latex_and"]
        else:
            interstr = "AND"
    if ctype == -1000:
        if style & PsiOpts.STR_STYLE_PSITIP:
            interstr = ">>"
        elif style & PsiOpts.STR_STYLE_LATEX:
            interstr = PsiOpts.settings["latex_matimplies"]
        else:
            interstr = "IMPLIES"
    # Canonical universe / empty regions get a fixed short representation.
    if self.isuniverse(sgn = True, canon = True):
        if style & PsiOpts.STR_STYLE_PSITIP:
            return spacestr * inden + "RegionOp.universe()"
        elif style & PsiOpts.STR_STYLE_LATEX:
            return spacestr * inden + PsiOpts.settings["latex_region_universe"]
        else:
            return spacestr * inden + "Universe"
    if self.isuniverse(sgn = False, canon = True):
        if style & PsiOpts.STR_STYLE_PSITIP:
            return spacestr * inden + "RegionOp.empty()"
        elif style & PsiOpts.STR_STYLE_LATEX:
            return spacestr * inden + PsiOpts.settings["latex_region_empty"]
        else:
            return spacestr * inden + "{}"
    # LaTeX quantifier prefix (outermost first), unless quantifiers go after the body.
    for x, c in reversed(self.auxs):
        if c:
            if style & PsiOpts.STR_STYLE_PSITIP:
                pass
            elif style & PsiOpts.STR_STYLE_LATEX:
                if not style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER:
                    r += PsiOpts.settings["latex_exists"] + " "
                    r += x.tostring(style = style, tosort = tosort)
                    r += PsiOpts.settings["latex_quantifier_sep"] + " "
        else:
            if style & PsiOpts.STR_STYLE_PSITIP:
                pass
            elif style & PsiOpts.STR_STYLE_LATEX:
                if not style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER:
                    r += PsiOpts.settings["latex_forall"] + " "
                    r += x.tostring(style = style, tosort = tosort)
                    r += PsiOpts.settings["latex_quantifier_sep"] + " "
    inden_inner = inden
    inden_inner1 = inden + 2
    # Opening bracket.
    if style & PsiOpts.STR_STYLE_PSITIP:
        r += spacestr * inden + "(" + nlstr
    elif style & PsiOpts.STR_STYLE_LATEX:
        if style & PsiOpts.STR_STYLE_LATEX_ARRAY:
            r += spacestr * inden
            if curadd_bracket:
                r += "\\left\\{"
            r += "\\begin{array}{l}\n"
            # Inside a LaTeX array the rows are not indented.
            inden_inner = 0
            inden_inner1 = 0
        else:
            r += spacestr * inden + "\\{" + nlstr
    else:
        r += spacestr * inden + "{" + nlstr
    # One string per term, with a NOT marker for negated terms.
    rlist = [spacestr * inden_inner1 + ("" if c else " " + notstr) +
        x.tostring(style = style, tosort = tosort, lhsvar = lhsvar, inden = inden_inner1,
        add_bracket = True, small = small).lstrip()
        for x, c in cregs]
    if tosort:
        rlist = zip(rlist, [any(x.ispresent(t) for t in lhsvar) for x, c in cregs])
        rlist = sorted(rlist, key=lambda a: (not a[1], len(a[0]), a[0]))
        rlist = [x for x, t in rlist]
    r += (nlstr + spacestr * inden_inner + " " + interstr + nlstr).join(rlist)
    r += nlstr + spacestr * inden_inner
    # Closing bracket, mirroring the opening one.
    if style & PsiOpts.STR_STYLE_PSITIP:
        r += ")"
    elif style & PsiOpts.STR_STYLE_LATEX:
        if style & PsiOpts.STR_STYLE_LATEX_ARRAY:
            r += "\\end{array}"
            if curadd_bracket:
                r += "\\right\\}"
        else:
            r += "\\}"
    else:
        r += "}"
    # Quantifier suffix: .exists(...)/.forall(...) in PSITIP style, ", exists ..." in plain
    # style, and in LaTeX only when quantifiers are placed after the body.
    for x, c in self.auxs:
        if c:
            if style & PsiOpts.STR_STYLE_PSITIP:
                r += ".exists("
            elif style & PsiOpts.STR_STYLE_LATEX:
                if style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER:
                    r += " , " + PsiOpts.settings["latex_exists"] + " "
                else:
                    continue
            else:
                r += " , exists "
        else:
            if style & PsiOpts.STR_STYLE_PSITIP:
                r += ".forall("
            elif style & PsiOpts.STR_STYLE_LATEX:
                if style & PsiOpts.STR_STYLE_LATEX_QUANTAFTER:
                    r += " , " + PsiOpts.settings["latex_forall"] + " "
                else:
                    continue
            else:
                r += " , forall "
        r += x.tostring(style = style, tosort = tosort)
        if style & PsiOpts.STR_STYLE_PSITIP:
            r += ")"
    return r
def __hash__(self):
    """Hash of the region type, the unordered set of signed terms, the ordered quantifiers, inp and oup."""
    kind = self.rtype
    terms_key = frozenset((hash(term), sign) for term, sign in self.regs)
    quant_key = tuple((hash(comp), flag) for comp, flag in self.auxs)
    return hash((kind, hash(terms_key), hash(quant_key), hash(self.inp), hash(self.oup)))
class MonotoneSet:
    """Monotone family of elements, stored as a list of generators in ``cache``.

    sgn > 0  : x is a member iff x >= g for some generator g.
    sgn < 0  : x is a member iff x <= g for some generator g.
    sgn == 0 : generators are stored but nothing is ever a member.

    Elements only need the >= / <= operators (e.g. numbers, or sets under
    inclusion). A successful membership test moves the matching generator to
    the front of ``cache``, so generators that are hit often are tried first.
    """

    def __init__(self, sgn = 1):
        self.sgn = sgn
        self.cache = []

    def _dominates(self, a, b):
        """Raw comparison telling whether a lies in the family generated by b."""
        if self.sgn > 0:
            return a >= b
        if self.sgn < 0:
            return a <= b
        return False

    def add(self, x):
        """Add generator x, discarding stored generators that x already covers."""
        if self.sgn != 0:
            self.cache = [g for g in self.cache if not self._dominates(g, x)]
        self.cache.append(x)

    def __contains__(self, x):
        for pos, g in enumerate(self.cache):
            if self._dominates(x, g):
                # Move-to-front: same effect as a chain of adjacent swaps down to index 0.
                self.cache.insert(0, self.cache.pop(pos))
                return True
        return False

    def __len__(self):
        return len(self.cache)

    def __getitem__(self, key):
        return self.cache[key]
class IBaseArray(IBaseObj):
    """Generic n-dimensional array of psitip objects.

    Entries are stored flat, in row-major order, in the list ``self.x``; the
    shape is stored separately in ``self._shape`` (see the ``shape`` property).
    Subclasses provide the class attributes ``cls_name``, ``entry_cls``,
    ``entry_cls_zero``, ``entry_cls_one``, ``tostring_bracket_needed`` and the
    static method ``entry_convert`` (returns a converted entry, or None if the
    value is not a single entry).
    """

    def __init__(self, x = None, shape = None, meta = None):
        """Build an array.

        Parameters:
            x     : A single ``entry_cls`` object (0-d array), a tensor (anything
                    accepted by ``iutil.istensor``), a (nested) list/tuple of
                    entries, or a dict. A dict becomes a list of ``[key, value]``
                    rows; a list value yields one row per item.
            shape : Explicit shape, overriding the one inferred from ``x``.
            meta  : Metadata dict (stored as given, not copied).
        Raises:
            ValueError : if nested lists are ragged.
        """
        if x is None:
            x = []
        if isinstance(x, dict):
            d = x
            x = []
            for key, value in d.items():
                if isinstance(value, list):
                    for t in value:
                        x.append([key, t])
                else:
                    x.append([key, value])
        cshape = None
        if isinstance(x, type(self).entry_cls):
            self.x = [x]
            cshape = tuple()
        elif iutil.istensor(x):
            self.x = []
            cshape = tuple(x.shape)
            for xs in itertools.product(*[range(t) for t in cshape]):
                self.x.append(self.entry_convert(x[xs]))
        else:
            self.x = []
            tshape = []

            def recur(i, y):
                # Depth-first walk: tshape[i] is the length seen at depth i,
                # and every later list at that depth must match it.
                if isinstance(y, IBaseArray):
                    y = y.tolist()
                if not isinstance(y, (list, tuple)):
                    self.x.append(self.entry_convert(y))
                    return
                if len(tshape) <= i:
                    tshape.append(len(y))
                elif len(y) != tshape[i]:
                    raise ValueError("Shape mismatch.")
                for z in y:
                    recur(i + 1, z)

            recur(0, x)
            cshape = tuple(tshape)
        if shape is not None:
            self.shape = shape
        else:
            self.shape = cshape
        if meta is None:
            self.meta = dict()
        else:
            self.meta = meta

    @property
    def shape(self):
        """Shape of the array as a tuple of ints."""
        return self._shape

    @shape.setter
    def shape(self, value):
        # An int means a 1-d shape. The first negative entry is inferred from
        # the number of stored entries (like numpy's -1).
        if isinstance(value, int):
            value = (value,)
        for i in range(len(value)):
            if value[i] < 0:
                tvalue = list(value)
                tvalue[i] = len(self.x) // (iutil.product(value[:i]) * iutil.product(value[i+1:]))
                value = tuple(tvalue)
                break
        self._shape = tuple(value)

    def reshaped(self, newshape):
        """Return a copy with shape ``newshape`` (entries are not reordered)."""
        r = self.copy()
        r.shape = newshape
        return r

    def copy(self):
        """Copy of the array: entries copied, meta copied via ``iutil.copy``."""
        return type(self)([a.copy() for a in self.x], shape = self.shape, meta = iutil.copy(self.meta))

    @classmethod
    def empty(cls, shape = 0):
        """Array of the given shape filled with ``entry_cls_zero()``."""
        if isinstance(shape, int):
            shape = (shape,)
        n = iutil.product(shape)
        return cls([cls.entry_cls_zero() for i in range(n)], shape = shape)

    @classmethod
    def zeros(cls, shape = 0):
        """Alias of ``empty``."""
        return cls.empty(shape = shape)

    @classmethod
    def ones(cls, shape = 0, x = None):
        """Array of the given shape filled with ``entry_cls_one()``, or copies of ``x``."""
        if isinstance(shape, int):
            shape = (shape,)
        n = iutil.product(shape)
        if x is None:
            return cls([cls.entry_cls_one() for i in range(n)], shape = shape)
        return cls([x.copy() for i in range(n)], shape = shape)

    @classmethod
    def eye(cls, n, x = None):
        """n-by-n array with ``entry_cls_one()`` (or copies of ``x``) on the diagonal."""
        r = cls.zeros((n, n))
        for i in range(n):
            if x is None:
                r[i, i] = cls.entry_cls_one()
            else:
                r[i, i] = x.copy()
        return r

    @classmethod
    def make(cls, *args):
        """1-d array from the arguments; iterable arguments are flattened one level."""
        r = cls.empty()
        for a in args:
            t = cls.entry_convert(a)
            if t is not None:
                r.append(t)
            else:
                for b in a:
                    r.append(b.copy())
        return r

    @classmethod
    def isthis(cls, x):
        """Whether ``x`` is this array type, or a non-empty list starting with an entry."""
        if isinstance(x, cls):
            return True
        if isinstance(x, list) and len(x) > 0 and isinstance(x[0], cls.entry_cls):
            return True
        return False

    def tolist(self, key = None, process = None):
        """Nested list of entries (optionally mapped through ``process``).

        ``key`` is an index prefix; only the sub-array below it is returned.
        """
        if key is None:
            key = tuple()
        shape = self.shape
        i = len(key)
        if i == len(shape):
            if process is not None:
                return process(self[key])
            return self[key]
        return [self.tolist(key + (j,), process = process) for j in range(shape[i])]

    def to_numpy(self):
        return numpy.array(self.tolist())

    def to_dict(self):
        """Interpret the last axis as ``(key, value)`` pairs and return a dict."""
        shape = self.shape
        if len(shape) < 2:
            return dict()
        r = dict()
        for xs in itertools.product(*[range(t) for t in shape[:-1]]):
            r[self[xs + (0,)]] = self[xs + (1,)]
        return r

    def find_dict(self, key):
        """Value paired with ``key`` (last match wins), treating rows as (key, value)."""
        shape = self.shape
        if len(shape) < 2:
            return None
        r = None
        for xs in itertools.product(*[range(t) for t in shape[:-1]]):
            if self[xs + (0,)] == key:
                r = self[xs + (1,)]
        return r

    def iadd_noduplicate(self, x):
        """Append entries of ``x`` whose ``get_name()`` is not already present."""
        nameset = set(a.get_name() for a in self.x)
        for a in x:
            cname = a.get_name()
            if cname not in nameset:
                nameset.add(cname)
                self.append(a)

    def series(self, vdir):
        """Get past or future sequence.
        Parameters:
            vdir : Direction, 1: future non-strict, 2: future strict,
                   -1: past non-strict, -2: past strict
        """
        zero = type(self).entry_cls_zero
        n = len(self.x)
        if vdir == 1:
            return type(self)([sum(self.x[i:], zero()) for i in range(n)])
        elif vdir == 2:
            return type(self)([sum(self.x[i+1:], zero()) for i in range(n)])
        elif vdir == -1:
            return type(self)([sum(self.x[:i+1], zero()) for i in range(n)])
        elif vdir == -2:
            return type(self)([sum(self.x[:i], zero()) for i in range(n)])
        return self.copy()

    def past_ns(self):
        return self.series(-1)

    def past(self):
        return self.series(-2)

    def future_ns(self):
        return self.series(1)

    def future(self):
        return self.series(2)

    @property
    def PN(self):
        return self.series(-1)

    @property
    def P(self):
        return self.series(-2)

    @property
    def FN(self):
        return self.series(1)

    @property
    def F(self):
        return self.series(2)

    def subsets(self, ndim = 1, minsize = 0, maxsize = 100000, size = None, reverse = False):
        """Subsets of this array along the first ndim (default = 1) dimensions, as a generator
        """
        items = [self[tuple(xs)] for xs in itertools.product(*[range(t) for t in self.shape[:ndim]])]
        return igen.subset(items, minsize = minsize, maxsize = maxsize, size = size, reverse = reverse)

    def allcomp(self):
        return sum([a.allcomp() for a in self.x], Comp.empty())

    def find(self, *args):
        return self.allcomp().find(*args)

    def from_mask(self, mask):
        """Return subset using bit mask."""
        r = type(self).entry_cls_zero()
        for i in range(len(self.x)):
            if mask & (1 << i) != 0:
                r += self.x[i]
        return r

    def append(self, a):
        """Append entry ``a`` (1-d arrays only). Raises ValueError otherwise."""
        if len(self.shape) != 1:
            raise ValueError("Can only append 1D array.")
        self.x.append(a)
        self.shape = (self.shape[0] + 1,)

    def swapped_id(self, i, j):
        """Copy with flat entries i and j swapped (plain copy if either is out of range)."""
        r = self.copy()
        if i >= len(self.x) or j >= len(self.x):
            return r
        r.x[i], r.x[j] = r.x[j], r.x[i]
        return r

    def transpose(self):
        """Reverse the order of all axes."""
        r = type(self).zeros(tuple(reversed(self.shape)))
        for xs in itertools.product(*[range(t) for t in self.shape]):
            r[tuple(reversed(xs))] = self[xs]
        return r

    @fcn_substitute
    def substitute(self, v0, v1):
        """Substitute variable v0 by v1 (v1 can be compound), in place."""
        for i in range(len(self.x)):
            self.x[i].substitute(v0, v1)
        return self

    @fcn_substitute
    def substitute_whole(self, v0, v1):
        """Substitute variable v0 by v1 (v1 can be compound), in place."""
        for i in range(len(self.x)):
            self.x[i].substitute_whole(v0, v1)
        return self

    @fcn_substitute
    def substitute_aux(self, v0, v1):
        """Substitute variable v0 by v1 (v1 can be compound), and remove auxiliary v0, in place."""
        for i in range(len(self.x)):
            self.x[i].substitute_aux(v0, v1)
        return self

    def substituted_aux(self, *args, **kwargs):
        """Substitute variable v0 by v1 (v1 can be compound), and remove auxiliary v0, return result"""
        r = self.copy()
        r.substitute_aux(*args, **kwargs)
        return r

    def set_len(self, n):
        """Truncate, or pad with ``entry_cls_zero()``, the flat entry list to length n.

        For 0-d and 1-d arrays the shape is updated to match. Previously the
        stored shape was left stale. Higher-dimensional shapes are left as-is.
        """
        if n < len(self.x):
            self.x = self.x[:n]
        while n > len(self.x):
            self.x.append(type(self).entry_cls_zero())
        if len(self.shape) <= 1:
            self._shape = (len(self.x),)

    def __neg__(self):
        return type(self)([-a for a in self.x], shape = self.shape)

    def __iadd__(self, other):
        # Tensor operands are combined element-wise (shapes must match);
        # anything else is broadcast to every entry.
        if isinstance(other, (tuple, list, ConcDist)):
            other = type(self)(other)
        if iutil.istensor(other):
            if self.shape != other.shape:
                raise ValueError("Shape mismatch.")
            for xs in itertools.product(*[range(t) for t in self.shape]):
                self[xs] += other[xs]
            return self
        for i in range(len(self.x)):
            self.x[i] += other
        return self

    def __imul__(self, other):
        if isinstance(other, (tuple, list, ConcDist)):
            other = type(self)(other)
        if iutil.istensor(other):
            if self.shape != other.shape:
                raise ValueError("Shape mismatch.")
            for xs in itertools.product(*[range(t) for t in self.shape]):
                self[xs] *= other[xs]
            return self
        for i in range(len(self.x)):
            self.x[i] *= other
        return self

    def __itruediv__(self, other):
        if isinstance(other, (tuple, list, ConcDist)):
            other = type(self)(other)
        if iutil.istensor(other):
            if self.shape != other.shape:
                raise ValueError("Shape mismatch.")
            for xs in itertools.product(*[range(t) for t in self.shape]):
                self[xs] /= other[xs]
            return self
        for i in range(len(self.x)):
            self.x[i] /= other
        return self

    def __ipow__(self, other):
        if isinstance(other, (tuple, list, ConcDist)):
            other = type(self)(other)
        if iutil.istensor(other):
            if self.shape != other.shape:
                raise ValueError("Shape mismatch.")
            for xs in itertools.product(*[range(t) for t in self.shape]):
                self[xs] **= other[xs]
            return self
        for i in range(len(self.x)):
            self.x[i] **= other
        return self

    def __mul__(self, other):
        r = self.copy()
        r *= other
        return r

    def __rmul__(self, other):
        r = self.copy()
        r *= other
        return r

    def __truediv__(self, other):
        r = self.copy()
        r /= other
        return r

    def __rtruediv__(self, other):
        # Build the ones-array with this array's full shape (not a flat 1-d
        # length) so the element-wise division also works for n-d arrays.
        r = type(self).ones(self.shape)
        r *= other
        r /= self
        return r

    def __pow__(self, other):
        r = self.copy()
        r **= other
        return r

    def __rpow__(self, other):
        # Same shape handling as __rtruediv__.
        r = type(self).ones(self.shape)
        r *= other
        r **= self
        return r

    def __add__(self, other):
        # 0 + array is accepted so that sum() over arrays works.
        if isinstance(other, int) and other == 0:
            return self.copy()
        if isinstance(other, (tuple, list, ConcDist)):
            other = type(self)(other)
        r = self.copy()
        r += other
        return r

    def __radd__(self, other):
        if isinstance(other, int) and other == 0:
            return self.copy()
        if isinstance(other, (tuple, list, ConcDist)):
            other = type(self)(other)
        r = self.copy()
        r += other
        return r

    def __isub__(self, other):
        if isinstance(other, (tuple, list, ConcDist)):
            other = type(self)(other)
        self += -other
        return self

    def __sub__(self, other):
        if isinstance(other, (tuple, list, ConcDist)):
            other = type(self)(other)
        r = self.copy()
        r += -other
        return r

    def __rsub__(self, other):
        if isinstance(other, (tuple, list, ConcDist)):
            other = type(self)(other)
        r = -self
        r += other
        return r

    def __len__(self):
        return len(self.x)

    def product(self):
        """Product of all entries (the int 1 for an empty array)."""
        if len(self.x) == 0:
            return 1
        r = None
        for a in self.x:
            if r is None:
                r = a.copy()
            else:
                r *= a
        return r

    def key_to_id(self, key):
        """Row-major flat index for an int or a full-length tuple of ints.

        Raises IndexError on a dimension mismatch or an out-of-range index.
        Other key types map to 0.
        """
        if isinstance(key, int):
            key = (key,)
        if isinstance(key, tuple) and all(isinstance(xs, int) for xs in key):
            shape = self.shape
            if len(key) != len(shape):
                raise IndexError("Dimension mismatch.")
            i = 0
            for s, k in zip(shape, key):
                if k < 0 or k >= s:
                    raise IndexError("Index out of bound.")
                i = i * s + k
            return i
        return 0

    def getitem_str(self, key):
        """Sum of the entries selected by the characters of ``key`` (case-insensitive).

        "0"/"c": x[0]; "p": x[1]; "f": x[2]; "a": all entries (x[1] first).
        """
        key = key.lower()
        r = type(self).entry_cls_zero()
        for c in key:
            if c == "0" or c == "c":
                r += self.x[0]
            elif c == "p" and len(self.x) > 1:
                r += self.x[1]
            elif c == "f" and len(self.x) > 2:
                r += self.x[2]
            elif c == "a":
                if len(self.x) > 1:
                    r += self.x[1]
                for i, y in enumerate(self.x):
                    if i != 1:
                        r += y
        return r

    def __getitem__(self, key):
        """Indexing.

        int: flat entry; str: ``getitem_str``; IBaseObj: ``find_dict``;
        slice: sub-array of flat entries; tuple shorter than the shape:
        sub-array; full tuple: single entry.
        """
        if isinstance(key, int):
            return self.x[key]
        if isinstance(key, str):
            return self.getitem_str(key)
        if isinstance(key, IBaseObj):
            return self.find_dict(key)
        if isinstance(key, slice):
            r = self.x[key]
            if isinstance(r, list):
                return type(self)(r)
            return r
        if isinstance(key, tuple) and len(key) < len(self.shape):
            return type(self)(self.tolist(key))
        return self.x[self.key_to_id(key)]

    def __setitem__(self, key, item):
        """Set an entry by flat int index or full index tuple (converted via entry_convert)."""
        if isinstance(key, int):
            # Flat index (negative allowed). Must not fall through to
            # key_to_id, which rejects negative or partial keys after the write.
            self.x[key] = self.entry_convert(item)
            return
        self.x[self.key_to_id(key)] = self.entry_convert(item)

    def dot(self, other):
        """ Dot product like numpy.dot.

        Contracts the last axis of self with the second-to-last axis of
        ``other`` (its only axis if 1-d). A list or tuple ``other`` is
        converted to an array first. A non-tensor ``other`` gives
        ``self * other``. A 0-d result is returned as a single entry.
        """
        if isinstance(other, (list, tuple)):
            other = type(self)(other)
        if not iutil.istensor(other):
            return self * other
        selfshape = self.shape
        othershape = other.shape
        self_np = max(len(selfshape) - 1, 0)
        other_np = max(len(othershape) - 2, 0)
        if selfshape[self_np] != othershape[other_np]:
            raise ValueError("Shape mismatch.")
        r = type(self).zeros(selfshape[:self_np] + othershape[:other_np] + othershape[other_np+1:])
        for xs in itertools.product(*[range(t) for t in selfshape[:self_np]]):
            for ys in itertools.product(*[range(t) for t in othershape[:other_np]]):
                for i in range(selfshape[self_np]):
                    for zs in itertools.product(*[range(t) for t in othershape[other_np+1:]]):
                        r[xs + ys + zs] += self[xs + (i,)] * other[ys + (i,) + zs]
        if len(r.shape) == 0:
            return r[tuple()]
        return r

    def __matmul__(self, other):
        return self.dot(other)

    def __rmatmul__(self, other):
        """``other @ self`` for a non-array left operand (e.g. a nested list)."""
        return type(self)(other).dot(self)

    def __imatmul__(self, other):
        """``self @= other`` computes ``self @ other`` (returns a new object).

        Previously the operands were reversed (``other @ self``).
        """
        return self.dot(other)

    def trace_mat(self):
        """Trace of matrix.

        Sums entries [i, i, ...] over the first two axes; remaining axes are kept.
        """
        selfshape = self.shape
        r = type(self).zeros(selfshape[2:])
        for i in range(min(selfshape[0], selfshape[1])):
            for xs in itertools.product(*[range(t) for t in selfshape[2:]]):
                r[xs] += self[(i, i) + xs]
        if len(r.shape) == 0:
            return r[tuple()]
        return r

    def trace(self):
        """Trace.

        Sum of entries [i, i, ..., i] over all axes.
        """
        selfshape = self.shape
        n = min(selfshape)
        return sum((self[(i,) * len(selfshape)] for i in range(n)), type(self).entry_cls_zero())

    def diag(self):
        """Return diagonal.

        1-d array of entries [i, i, ..., i].
        """
        selfshape = self.shape
        n = min(selfshape)
        return type(self)([self[(i,) * len(selfshape)] for i in range(n)])

    def record_to(self, index):
        for a in self.x:
            a.record_to(index)

    def isregtermpresent(self):
        return any(a.isregtermpresent() for a in self.x)

    def get_sum(self):
        """Sum of all entries.
        """
        return sum(self.x, type(self).entry_cls_zero())

    def avg(self):
        """Average of all entries (raises ZeroDivisionError when empty).
        """
        return self.get_sum() / len(self)

    def tostring(self, style = 0, tosort = False, *args, **kwargs):
        """Convert to string
        Parameters:
            style : Style of string conversion
                STR_STYLE_STANDARD : I(X,Y;Z|W)
                STR_STYLE_PSITIP : I(X+Y&Z|W)
        """
        if len(self.shape) == 0:
            return ""
        style = iutil.convert_str_style(style)
        is_subs = len(self) > 0 and not (style & PsiOpts.STR_STYLE_PSITIP) and self.get_meta("subs") is True
        nlstr = "\n"
        if style & PsiOpts.STR_STYLE_LATEX:
            nlstr = "\\\\\n"
        latex_hline = "\\hline\n"
        shape = self.shape
        r = ""
        add_bracket = True
        list_bracket0 = "["
        list_bracket1 = "]"
        interstr = ""
        if style & PsiOpts.STR_STYLE_LATEX:
            list_bracket0 = PsiOpts.settings["latex_list_bracket_l"]
            list_bracket1 = PsiOpts.settings["latex_list_bracket_r"]
        omit_bracket = len(self) > 0 and (self.get_meta("omit_bracket") is True)
        if not (style & PsiOpts.STR_STYLE_PSITIP) and omit_bracket:
            list_bracket0 = ""
            list_bracket1 = ""
        if is_subs:
            if style & PsiOpts.STR_STYLE_LATEX:
                interstr = PsiOpts.settings["latex_subs"]
                if len(shape) >= 2 and any(tshape > 1 for tshape in shape[:-1]):
                    list_bracket0 = PsiOpts.settings["latex_subs_bracket_l"]
                    list_bracket1 = PsiOpts.settings["latex_subs_bracket_r"]
                else:
                    list_bracket0 = ""
                    list_bracket1 = ""
            else:
                interstr = ":="
                list_bracket0 = ""
                list_bracket1 = ""
        if style & PsiOpts.STR_STYLE_PSITIP:
            r += type(self).cls_name + "("
            if len(shape) > 1:
                r += nlstr
            add_bracket = False
        isarray = False
        if style & PsiOpts.STR_STYLE_LATEX:
            if list_bracket0 != "":
                r += "\\left" + list_bracket0 + " "
            if not (len(self.shape) in [1, 2] and not any(tshape > 1 for tshape in shape[:-1])):
                isarray = True
                if is_subs and self.shape[-1] >= 2:
                    r += "\\begin{array}{" + "l" * self.shape[-1] + "}\n"
                else:
                    r += "\\begin{array}{" + "c" * self.shape[-1] + "}\n"
        for xs in itertools.product(*[range(t) for t in shape]):
            # si0: number of leading axes not at index 0 (where brackets open);
            # si1: number of leading axes not at their last index (where brackets close).
            si0 = len(shape)
            while si0 > 0 and xs[si0 - 1] == 0:
                si0 -= 1
            si1 = len(shape)
            while si1 > 0 and xs[si1 - 1] == shape[si1 - 1] - 1:
                si1 -= 1
            if si0 < len(shape) and not style & PsiOpts.STR_STYLE_LATEX:
                r += " " * si0 + list_bracket0 * (len(shape) - si0)
            if type(self).tostring_bracket_needed:
                r += self[xs].tostring(style = style, tosort = tosort, add_bracket = add_bracket)
            else:
                r += self[xs].tostring(style = style, tosort = tosort)
            if si1 < len(shape) and not style & PsiOpts.STR_STYLE_LATEX:
                r += list_bracket1 * (len(shape) - si1)
            if style & PsiOpts.STR_STYLE_LATEX:
                if si1 > 0:
                    if si1 == len(shape):
                        if isarray:
                            r += " & "
                        elif is_subs:
                            r += " "
                        else:
                            r += r" \;\; "
                        if interstr != "":
                            r += interstr + " "
                    else:
                        if latex_hline is not None:
                            cc = len(shape) - si1
                            r += nlstr
                            if cc > 1:
                                r += latex_hline
                            if cc > 2:
                                r += nlstr * (cc - 2)
                                r += latex_hline
                        else:
                            r += nlstr * (len(shape) - si1)
            else:
                if si1 > 0:
                    r += ","
                    if si1 == len(shape):
                        r += " "
                        if interstr != "":
                            r += interstr + " "
                    else:
                        r += nlstr * (len(shape) - si1)
        if style & PsiOpts.STR_STYLE_PSITIP:
            r += ")"
        if style & PsiOpts.STR_STYLE_LATEX:
            if isarray:
                r += "\\end{array}"
            if list_bracket1 != "":
                r += "\\right" + list_bracket1
        return r

    def __str__(self):
        return self.tostring(PsiOpts.settings["str_style"], PsiOpts.settings["str_tosort"])

    def __repr__(self):
        return self.tostring(PsiOpts.settings["str_style_repr"])

    @latex_postprocess
    def _latex_(self):
        return self.tostring(iutil.convert_str_style("latex"))
class CompArray(IBaseArray):
    """Array whose entries are ``Comp`` objects (random variables / compounds)."""
    cls_name = "CompArray"
    entry_cls = Comp
    entry_cls_zero = Comp.empty
    entry_cls_one = None
    tostring_bracket_needed = True

    @staticmethod
    def entry_convert(a):
        """Copy of ``a`` if it is a Comp, else None."""
        if isinstance(a, Comp):
            return a.copy()
        return None

    @staticmethod
    def arg_convert(b):
        """Turn a list or Comp argument into a CompArray; return anything else unchanged.

        Declared static so that it also works when called through an instance
        (previously it had neither ``self`` nor ``@staticmethod``).
        """
        if isinstance(b, list) or isinstance(b, Comp):
            return CompArray.make(*b)
        return b

    def get_comp(self):
        """Sum (joint) of all entries as a single Comp."""
        return sum(self.x, Comp.empty())

    def get_term(self):
        return Term([a.copy() for a in self.x])

    @property
    def p(self):
        """Strict past sequence (see ``series``)."""
        return self.past()

    @property
    def f(self):
        """Strict future sequence (see ``series``)."""
        return self.future()

    def series_list(self, name = None, suf0 = "Q", sufp = "P", suff = "F"):
        """List of ``(Comp.rv(name + suffix), sequence)`` pairs for present, past and future.

        When ``name`` is None, it is the longest common prefix of the entry
        names. A suffix of None skips that pair.
        """
        if name is None:
            if len(self.x) == 0:
                name = ""
            else:
                name = self.x[0].get_name()
                for a in self.x[1:]:
                    tname = a.get_name()
                    while not tname.startswith(name):
                        name = name[:-1]
        r = []
        if suf0 is not None:
            r.append((Comp.rv(name + suf0), self.copy()))
        if sufp is not None:
            r.append((Comp.rv(name + sufp), self.past()))
        if suff is not None:
            r.append((Comp.rv(name + suff), self.future()))
        return r

    @staticmethod
    def series_sym(x, sufp = "P", suff = "F"):
        """Series symbols (X, X_P, X_F), representing present, past and future.
        """
        if isinstance(x, str):
            x = Comp.rv(x)
        r = CompArray.empty()
        r.append(x)
        if sufp is not None:
            r.append(x.name_operation(("append_sub", sufp)))
        if suff is not None:
            r.append(x.name_operation(("append_sub", suff)))
        return r

    def __and__(self, other):
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return ExprArray([a & b for a, b in zip(self.x, other.x)], shape = self.shape)
        return ExprArray([a & other for a in self.x], shape = self.shape)

    def __or__(self, other):
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return ExprArray([a | b for a, b in zip(self.x, other.x)], shape = self.shape)
        return ExprArray([a | other for a in self.x], shape = self.shape)

    def __rand__(self, other):
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return ExprArray([b & a for a, b in zip(self.x, other.x)], shape = self.shape)
        return ExprArray([other & a for a in self.x], shape = self.shape)

    def __ror__(self, other):
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return ExprArray([b | a for a, b in zip(self.x, other.x)], shape = self.shape)
        return ExprArray([other | a for a in self.x], shape = self.shape)

    def mark(self, *args, **kwargs):
        for a in self.x:
            a.mark(*args, **kwargs)
        return self

    def set_card(self, m):
        """Set the cardinality of every entry to ``m``."""
        for a in self.x:
            a.set_card(m)
        return self

    def get_card(self):
        return self.get_comp().get_card()

    def get_shape(self):
        """Tuple of entry cardinalities; raises ValueError if any is unset."""
        r = []
        for a in self.x:
            t = a.get_card()
            if t is None:
                raise ValueError("Cardinality of " + str(a) + " not set. Use " + str(a) + ".set_card(m) to set cardinality.")
            r.append(t)
        return tuple(r)
class CheckResult(CompArray):
    """Result of a check: a truth value plus optional auxiliaries, proof and model.

    The CompArray content holds auxiliary choices. It is printed in
    substitution style, without brackets.
    """

    def __init__(self, *args, truth = None, method = None, reg = None, getaux = None, proof = None, model = None, display_reg = False, **kwargs):
        self.truth = truth
        self.method = method
        self.reg = reg
        self.getaux = getaux
        self.proof = proof
        self.model = model
        self.display_reg = display_reg
        CompArray.__init__(self, *args, **kwargs)
        self.add_meta("omit_bracket", True)
        self.add_meta("subs", True)

    def tostring(self, style = 0, tosort = False, *args, **kwargs):
        """Convert to string: truth value, then auxiliaries (if there is no proof), model, proof."""
        style = iutil.convert_str_style(style)
        is_latex = style & PsiOpts.STR_STYLE_LATEX
        r = ""
        nlstr = "\n"
        if is_latex:
            nlstr = "\\\\\n"
            r += "\\begin{array}{l}\n"
        if self.display_reg and self.reg is not None:
            r += self.reg.tostring(style)
            if is_latex:
                r += "\\;\\mathrm{is}\\;"
            else:
                r += " is "
        if self.truth is True:
            truthstr = "True"
        elif self.truth is False:
            truthstr = "False"
        else:
            truthstr = "Unknown"
        if is_latex:
            r += "\\mathrm{" + truthstr + "}"
        else:
            r += truthstr
        r += nlstr
        addproof = self.proof is not None
        if not addproof and len(self) > 0:
            # Pass style/tosort positionally: forwarding *args after keyword
            # style=/tosort= would bind extra positionals to 'style' again.
            r += CompArray.tostring(self, style, tosort, *args, **kwargs)
            r += nlstr
        if self.model is not None:
            r += nlstr
            r += self.model.tostring(style = style)
            r += nlstr
        if is_latex:
            r += "\\end{array}"
        if self.proof is not None and not self.proof.isempty():
            r += nlstr
            r = iutil.latex_concat(style, [r, self.proof.tostring(style)])
        return r

    def __getitem__(self, key):
        """Index the auxiliaries if any; otherwise the model (None if neither)."""
        if len(self):
            return CompArray.__getitem__(self, key)
        if self.model is not None:
            return self.model[key]
        return None

    def __bool__(self):
        # Only a definite True counts as true; None (unknown) is falsy.
        return bool(self.truth)
class ExprArray(IBaseArray):
    """Array whose entries are ``Expr`` objects (ints/floats are turned into constants)."""
    cls_name = "ExprArray"
    entry_cls = Expr
    entry_cls_zero = Expr.zero
    entry_cls_one = Expr.one
    tostring_bracket_needed = False

    @staticmethod
    def entry_convert(a):
        """Convert Expr/Term (copied), int/float or 0-d tensor to an entry; else None."""
        if isinstance(a, Expr):
            return a.copy()
        elif isinstance(a, Term):
            return a.copy()
        elif isinstance(a, (int, float)):
            return Expr.const(a)
        elif iutil.istensor(a) and len(a.shape) == 0:
            return Expr.const(float(a))
        return None

    def set_float(self, force_float = True):
        """Set the "float_style" meta of every entry."""
        for x in self.x:
            x.add_meta("float_style", force_float)

    def get_expr(self):
        """Sum of all entries."""
        return sum(self.x, Expr.zero())

    def isconst(self):
        """Whether every entry has a constant value (``get_const()`` is not None)."""
        return all(a.get_const() is not None for a in self.x)

    def to_numpy(self):
        """numpy array of constants if all entries are constant, else of Expr objects."""
        if self.isconst():
            return numpy.array(self.tolist(process = lambda a: a.get_const()))
        return IBaseArray.to_numpy(self)

    def __abs__(self):
        return eabs(self)

    def __and__(self, other):
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return ExprArray([a & b for a, b in zip(self.x, other.x)], shape = self.shape)
        return ExprArray([a & other for a in self.x], shape = self.shape)

    def __or__(self, other):
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return ExprArray([a | b for a, b in zip(self.x, other.x)], shape = self.shape)
        return ExprArray([a | other for a in self.x], shape = self.shape)

    def __rand__(self, other):
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return ExprArray([b & a for a, b in zip(self.x, other.x)], shape = self.shape)
        return ExprArray([other & a for a in self.x], shape = self.shape)

    def __ror__(self, other):
        if isinstance(other, CompArray) or isinstance(other, ExprArray):
            return ExprArray([b | a for a, b in zip(self.x, other.x)], shape = self.shape)
        return ExprArray([other | a for a in self.x], shape = self.shape)

    def ge_region(self):
        """Region stating every entry >= 0."""
        r = Region.universe()
        for a in self.x:
            r.exprs_ge.append(a.copy())
        return r

    def eq_region(self):
        """Region stating every entry == 0."""
        r = Region.universe()
        for a in self.x:
            r.exprs_eq.append(a.copy())
        return r

    def __ge__(self, other):
        return (self - other).ge_region()

    def __le__(self, other):
        return (other - self).ge_region()

    def agree_shape_in(self, other):
        """The "prob_shape_in" meta if ``other`` (when an ExprArray) has the same one, else None."""
        shape_in = self.meta.get("prob_shape_in", None)
        if shape_in is None:
            return None
        if isinstance(other, ExprArray):
            other_shape_in = other.meta.get("prob_shape_in", None)
            if other_shape_in is None or other_shape_in != shape_in:
                return None
        return shape_in

    def __eq__(self, other):
        # Returns a Region (constraint), not a bool.
        return (self - other).eq_region()

    def discover_ic_lex(self, x):
        """Discover conditional independence relations among random variables in x using this entropy vector.
        The entropy vector must be ordered in lexicographical order (e.g. obtained by ent_vector_lex(x)).
        """
        return Region.ent_vector_discover_ic(self, x)

    def discover_lex(self, x):
        """Discover relations among random variables in x using this entropy vector.
        The entropy vector must be ordered in lexicographical order (e.g. obtained by ent_vector_lex(x)).
        """
        return (self.discover_ic_lex(x) & (ent_vector_lex(*x) == self)).simplified()

    @staticmethod
    def fcn(fcncall, shape):
        """Array of ``Expr.fcn`` entries; the entry at index xs evaluates ``fcncall(P, idx)``.

        ``idx`` is the int index for a 1-d shape and the index tuple otherwise.
        Declared static so that it also works when called through an instance
        (previously it had neither ``self`` nor ``@staticmethod``).
        """
        if isinstance(shape, int):
            shape = (shape,)

        def make_entry(txs):
            # Separate scope binds txs per entry (avoids late-binding of xs).
            return Expr.fcn(lambda P: fcncall(P, txs[0] if len(txs) == 1 else txs))

        return ExprArray([make_entry(xs) for xs in itertools.product(*[range(t) for t in shape])],
                         shape = shape)

    def __invert__(self):
        return self == 0
class ivenn:
def ellipse_contains(el, p):
if len(el) < 4:
el = (el[0], el[1], numpy.linalg.inv(el[1]), True)
elcen, elm, elmi, elinc = el
return not elinc ^ (numpy.linalg.norm(elmi.dot(p - elcen)) <= 1.0)
def ellipse_angle(el, p):
if len(el) < 4:
el = (el[0], el[1], numpy.linalg.inv(el[1]), True)
elcen, elm, elmi, elinc = el
t = elmi.dot(p - elcen)
return numpy.arctan2(t[1], t[0])
def ellipse_clamp(el, p, border, inc, is_it):
if len(el) < 4:
el = (el[0], el[1], numpy.linalg.inv(el[1]), True)
elcen, elm, elmi, elinc = el
pme = p - elcen
t = elmi.dot(pme)
tn = numpy.linalg.norm(t)
if tn <= 1e-9:
return None
if is_it:
ndiv = 600
md = 1e20
mg = numpy.array([0.0, 0.0])
for i in range(ndiv + 1):
ia = i * numpy.pi * 2 / ndiv
g = elm.dot(numpy.array([numpy.cos(ia), numpy.sin(ia)]))
cd = numpy.linalg.norm(g - pme)
if cd < md:
md = cd
mg = g
if md < 1e-9:
return None
if (inc ^ (tn > 1.0)) and md >= border:
return None
if inc ^ (tn > 1.0):
return (pme - mg) * (border / md) + mg + elcen
else:
return -(pme - mg) * (border / md) + mg + elcen
else:
border /= numpy.linalg.norm(pme) / tn
if inc:
if tn <= 1.0 - border:
return None
t = t * (1.0 - border) / tn
else:
if tn >= 1.0 + border:
return None
t = t * (1.0 + border) / tn
return elm.dot(t) + elcen
def ellipse_intersect_it(els, r, ndiv = 500, maxnp = 10000, cp = None, stop0 = None, stop1 = None, stopp = None):
elcen, elm, elmi, elinc = els[0]
started = False
pcontain = True
pcontainxi = -1
angle0 = 0.0
anglesn = 1.0 if elinc else -1.0
if cp is not None:
angle0 = ivenn.ellipse_angle(els[0], cp)
for i in range(ndiv + 1):
ia = anglesn * i * numpy.pi * 2 / ndiv + angle0
p = elm.dot(numpy.array([numpy.cos(ia), numpy.sin(ia)])) + elcen
ccontains = [e is els[0] or ivenn.ellipse_contains(e, p) for e in els]
ccontain = all(ccontains)
ccontainxi = -1
for ei in range(len(els)):
if not ccontains[ei]:
ccontainxi = ei
if cp is None:
if ccontain and not pcontain:
ivenn.ellipse_intersect_it(els, r, ndiv, maxnp, p, els[pcontainxi], els[0], p)
return
else:
if started and not ccontain:
for ei in range(len(els)):
if not ccontains[ei]:
if els[0] is stop0 and els[ei] is stop1 and numpy.linalg.norm(p - stopp) <= 0.02:
return
cels = [els[ei]] + [e for e in els if e is not els[ei]]
ivenn.ellipse_intersect_it(cels, r, ndiv, maxnp, p, stop0, stop1, stopp)
return
return
if ccontain:
started = True
if started:
r.append(p)
if len(r) >= maxnp:
print("Venn diagram intersection failure!")
return
pcontain = ccontain
pcontainxi = ccontainxi
if cp is None:
if len(els) >= 2:
cels = els[1:] + [els[0]]
ivenn.ellipse_intersect_it(cels, r, ndiv, maxnp, cp, stop0, stop1, stopp)
return
# print("Venn diagram includes all!")
for i in range(ndiv):
ia = anglesn * i * numpy.pi * 2 / ndiv + angle0
p = elm.dot(numpy.array([numpy.cos(ia), numpy.sin(ia)])) + elcen
r.append(p)
return
def ellipse_intersect(els, ndiv = 700):
cels = [(e[0], e[1], numpy.linalg.inv(e[1]), e[2] if len(e) >= 3 else True) for e in els]
r = []
ivenn.ellipse_intersect_it(cels, r, ndiv)
return r
def ellipses(n):
    """Return the ellipse layout used for an ``n``-set Venn diagram.

    Each entry is ``(center, m)``: a length-2 numpy array and a 2x2 numpy
    array. ``ellipse_intersect_it`` samples boundary points as
    ``m.dot([cos t, sin t]) + center``. Layouts exist for n = 1..5; any
    other ``n`` yields None.
    """
    if n == 1:
        return [(numpy.array([0.0, 0.0]), numpy.eye(2))]
    if n == 2:
        ratio = 0.68
        left = (numpy.array([ratio - 1.0, 0.0]), numpy.eye(2) * ratio)
        right = (numpy.array([-ratio + 1.0, 0.0]), numpy.eye(2) * ratio)
        return [left, right]
    if n == 3:
        # Three equal circles whose centres sit on a ring of radius ``rad``.
        ratio = 0.65
        rad = -ratio + 1.0
        result = []
        for k in range(3):
            theta = k * numpy.pi * 2 / 3
            centre = numpy.array([-numpy.sin(theta) * rad, numpy.cos(theta) * rad])
            result.append((centre, numpy.eye(2) * ratio))
        return result
    if n == 4:
        def shear(k):
            # Lower-triangular shear matrix [[1, 0], [k, 1]].
            return numpy.array([[1.0, 0.0], [k, 1.0]])
        return [(numpy.array([-0.4 * 1.011, -0.17]), shear(-0.7) * 0.7),
                (numpy.array([-0.1, 0.15]), shear(-0.8) * 0.62),
                (numpy.array([0.1, 0.15]), shear(0.8) * 0.62),
                (numpy.array([0.4 * 1.011, -0.17]), shear(0.7) * 0.7)]
    if n == 5:
        # 5-set Venn diagram by Branko Grunbaum: one ellipse rotated five times.
        ratio = 0.85
        offset = numpy.array([0.04 * 2, 0.087 * 2])
        axes = numpy.array([[0.63, 0.0], [0.0, 1.0]])
        result = []
        for k in range(5):
            a = -k * numpy.pi * 2 / 5 + 0.2
            c = numpy.cos(a)
            s = numpy.sin(a)
            rot = numpy.array([[c, -s], [s, c]])
            result.append((rot.dot(offset) * ratio, rot.dot(axes) * ratio))
        return result
    return None
def intersect(n, inc):
    """Return the boundary of the intersection of the ellipses selected by ``inc``.

    ``inc[i]`` truthy selects ellipse ``i`` of the ``n``-set layout.
    Unselected ellipses are ignored rather than excluded, so
    ``intersect(n, [j == i for j in range(n)])`` gives ellipse i's full outline.
    """
    layout = ivenn.ellipses(n)
    chosen = [layout[i] for i in range(len(layout)) if inc[i]]
    return ivenn.ellipse_intersect(chosen)
def calc_ellipses(n):
    """Compute the drawing geometry of every region of an ``n``-set Venn diagram.

    Returns ``(r, textps)``.

    ``r`` has ``1 << n`` entries indexed by a bitmask over the ellipses. Bit i
    set means inside ellipse i; bit i clear means outside it. Entry 0 stays
    ``[None, None, None]``. For ``mask > 0``:
      r[mask][0]  boundary points of the region (from ivenn.ellipse_intersect);
      r[mask][1]  position of the region's value label (hard-coded table);
      r[mask][2]  boundary point minimizing ``y + distance to r[mask][1]``.
                  It is used as the anchor for the "+"/"-" sign text, and is
                  [0, 0] when the boundary is empty.
    ``textps`` holds the position of each ellipse's name label.

    Raises:
        ValueError: if ``ivenn.ellipses`` has no layout for ``n`` (only
        1..5 are supported).
    """
    els = ivenn.ellipses(n)
    if els is None:
        raise ValueError("Venn diagram layout is only available for 1 to 5 sets, got " + str(n))
    nmask = 1 << n
    r = [[None, None, None] for i in range(nmask)]
    for mask in range(nmask - 1, 0, -1):
        # Ellipse i is included in the region iff bit i of mask is set.
        r[mask][0] = ivenn.ellipse_intersect(
            [(e[0], e[1], bool(mask & (1 << i))) for i, e in enumerate(els)])
        if len(r[mask][0]) == 0:
            print("Venn diagram intersection empty!")
        r[mask][2] = numpy.array([0.0, 0.0])

    # Positions of the variable-name labels, one per ellipse.
    textps = []
    if n == 1:
        textps = [(0.0, 1.1)]
    elif n == 2:
        textps = [(-0.45, 0.8), (0.45, 0.8)]
    elif n == 3:
        textps = [(0.0, 1.1), (-1.0, -0.6), (1.0, -0.6)]
    elif n == 4:
        textps = [(-0.95, 0.85), (-0.4, 1.1), (0.4, 1.1), (0.95, 0.85)]
    elif n == 5:
        textps = [(-numpy.sin(-i * numpy.pi * 2 / 5 + 0.12) * 1.1,
                   numpy.cos(-i * numpy.pi * 2 / 5 + 0.12) * 1.1) for i in range(5)]

    # Value-label position inside each region. These are fixed constants
    # tuned for the layouts in ivenn.ellipses; update them by hand if a
    # layout changes.
    if n == 1:
        r[1][1] = numpy.array([0.0, 0.0])
    elif n == 2:
        r[3][1] = numpy.array([-5.19979865e-06, 0.0])
        r[2][1] = numpy.array([6.39999982e-01, 0.0])
        r[1][1] = numpy.array([-6.39999982e-01, 0.0])
    elif n == 3:
        r[7][1] = numpy.array([-0.00011776, 0.00020427])
        r[6][1] = numpy.array([8.10570556e-05, -4.61882452e-01])
        r[5][1] = numpy.array([0.40013006, 0.2308303])
        r[4][1] = numpy.array([0.60932452, -0.35185475])
        r[3][1] = numpy.array([-0.40005325, 0.23087825])
        r[2][1] = numpy.array([-0.6093162, -0.35186916])
        r[1][1] = numpy.array([1.71157142e-05, 7.03617892e-01])
    elif n == 4:
        r[15][1] = numpy.array([0.0, -0.24960125])
        r[14][1] = numpy.array([0.24123917, 0.14045396 - 0.04])
        r[13][1] = numpy.array([-0.18729334, -0.53280546])
        r[12][1] = numpy.array([0.51662578, 0.40616016])
        r[11][1] = numpy.array([0.19303297, -0.51624427])
        r[10][1] = numpy.array([0.41826399 + 0.01, -0.41951261])
        r[9][1] = numpy.array([0.0, -0.78949216])
        r[8][1] = numpy.array([0.85559034, 0.22166595 - 0.02])
        r[7][1] = numpy.array([-0.24252916, 0.14045396 - 0.04])
        r[6][1] = numpy.array([0.0, 0.36160384 + 0.04])
        r[5][1] = numpy.array([-0.41826399 - 0.01, -0.41951261])
        r[4][1] = numpy.array([0.40438341, 0.74829018])
        r[3][1] = numpy.array([-0.51706528, 0.40634087])
        r[2][1] = numpy.array([-0.4085948, 0.74881886])
        r[1][1] = numpy.array([-0.85559034, 0.22166595 - 0.02])
    elif n == 5:
        r[31][1] = numpy.array([0.00123121, -0.00035098])
        r[30][1] = numpy.array([-0.50522841, -0.15738005])
        r[29][1] = numpy.array([-0.30334848, 0.43405041])
        r[28][1] = numpy.array([-0.5759732, 0.08623396])
        r[27][1] = numpy.array([0.31906663, 0.42263051])
        r[26][1] = numpy.array([-0.56718932, -0.29605626])
        r[25][1] = numpy.array([-0.09599438, 0.57443462])
        r[24][1] = numpy.array([-0.66609369, -0.20490312])
        r[23][1] = numpy.array([0.49741082, -0.18604116])
        r[22][1] = numpy.array([0.63288148, -0.09387176])
        r[21][1] = numpy.array([-0.45683543, 0.44794225])
        r[20][1] = numpy.array([-0.57279059, 0.30624])
        r[19][1] = numpy.array([0.51664827, 0.26884256])
        r[18][1] = numpy.array([0.64339464, 0.0889563])
        r[17][1] = numpy.array([-0.36041924, 0.57743233])
        r[16][1] = numpy.array([-0.8230906, 0.14612619])
        r[15][1] = numpy.array([-0.00456696, -0.52900733])
        r[14][1] = numpy.array([-0.25995857, -0.52117389])
        r[13][1] = numpy.array([0.10629509, -0.63091551])
        r[12][1] = numpy.array([-0.01095303, -0.69681429])
        r[11][1] = numpy.array([0.28484849, 0.57289808])
        r[10][1] = numpy.array([-0.46825182, -0.45012615])
        r[9][1] = numpy.array([0.11424952, 0.63938959])
        r[8][1] = numpy.array([-0.39334952, -0.73764132])
        r[7][1] = numpy.array([0.41524152, -0.40832889])
        r[6][1] = numpy.array([0.63606, -0.25034334])
        r[5][1] = numpy.array([0.28339764, -0.58443058])
        r[4][1] = numpy.array([0.5799869, -0.60204133])
        r[3][1] = numpy.array([0.43456033, 0.52772247])
        r[2][1] = numpy.array([0.75179328, 0.36557041])
        r[1][1] = numpy.array([-0.11537642, 0.82796063])

    # Sign-text anchor: the boundary point that is low (small y) and close
    # to the value label.
    for mask in range(nmask - 1, 0, -1):
        mdist = 1e10
        for x in r[mask][0]:
            cdist = x[1] + numpy.linalg.norm(r[mask][1] - x)
            if cdist < mdist:
                mdist = cdist
                r[mask][2] = x
    return (r, textps)
def patch(n, inc, **kwargs):
    """Return a closed matplotlib Polygon tracing ``ivenn.intersect(n, inc)``.

    Extra keyword arguments are forwarded to ``matplotlib.patches.Polygon``.
    """
    vertices = ivenn.intersect(n, inc)
    return matplotlib.patches.Polygon(vertices, closed=True, **kwargs)
class CellTable:
    """Per-cell data for drawing a diagram over the variables ``x``.

    There is one cell per bitmask ``mask`` in ``range(1 << len(x))``; cell 0
    is never drawn. In the Venn layout, cell ``mask`` is the region inside
    exactly the ellipses whose bits are set. In the grid layout it is one
    square of a Gray-code ordered grid (see ``get_pos``). Each cell is a dict
    of attributes; ``plot`` reads "enabled", "cval", "ispos", "isneg" and
    "val_<k>" (the value of the k-th expression added with ``add_expr``).
    """

    def __init__(self, x):
        """Create an empty table over the variables ``x``."""
        self.x = x
        self.cells = [{} for i in range(1 << len(self.x))]
        self.fontsize = 22
        self.linewidth = 1.5
        # Entries are dicts {"expr": ..., "cval": ...}; a "color" key is honoured if present.
        self.exprs = []

    def set_enabled(self, mask, enabled = True):
        """Mark cell ``mask`` as enabled or disabled (disabled cells are drawn black)."""
        self.cells[mask]["enabled"] = enabled

    def get_enabled(self, mask):
        """Return whether cell ``mask`` is enabled (default True)."""
        return self.cells[mask].get("enabled", True)

    def set_attr(self, mask, key, val):
        """Set attribute ``key`` of cell ``mask``."""
        self.cells[mask][key] = val

    def get_attr(self, mask, key, default_val = None):
        """Return attribute ``key`` of cell ``mask``, or ``default_val`` if unset."""
        return self.cells[mask].get(key, default_val)

    def add_expr(self, expr, cval):
        """Append an expression (shown in the legend) with its value ``cval``."""
        self.exprs.append({"expr": expr, "cval": cval})

    def set_expr_val(self, mask, val):
        """Set the value of the most recently added expression in cell ``mask``."""
        self.cells[mask]["val_" + str(len(self.exprs) - 1)] = val

    def get_pos(self, mask):
        """Return the grid position (column, row) of cell ``mask``.

        The bits above the lowest ``n // 2`` select the column; the lowest
        ``n // 2`` bits, bit-reversed, select the row. Both are decoded with
        ``iutil.gray_to_bin``, presumably so that neighbouring cells differ
        in a single variable.
        """
        n = len(self.x)
        nv = n // 2
        maskv = mask & ((1 << nv) - 1)
        maskh = mask >> nv
        return (iutil.gray_to_bin(maskh), iutil.gray_to_bin(iutil.bit_reverse(maskv, nv)))

    def get_x_poss(self, xit):
        """Return ``(axis, indices)`` of the grid rows/columns where variable ``xit`` is included.

        ``axis`` is 1 (rows) for the first ``n // 2`` variables and 0
        (columns) for the rest.
        """
        n = len(self.x)
        nv = n // 2
        cn = nv
        ax = 1
        if xit >= nv:
            xit -= nv
            ax = 0
            cn = n - nv
        else:
            xit = nv - 1 - xit
        r = [mask for mask in range(1, 1 << cn) if iutil.bin_to_gray(mask) & (1 << xit)]
        return (ax, r)

    @staticmethod
    def get_color(x):
        """Convert a color spec to an RGB tuple.

        ``None`` gives None and non-strings are returned unchanged. Strings
        are looked up (case-insensitively) among matplotlib's base and CSS4
        color names, with the one-letter names redefined as pure RGB
        primaries. A prefix "l_" lightens the color and "d_" darkens it.
        Unknown names give None.
        """
        if x is None:
            return None
        if isinstance(x, str):
            mode = None
            x = x.lower().strip()
            if x.startswith("l_"):
                x = x[2:]
                mode = "l"
            elif x.startswith("d_"):
                x = x[2:]
                mode = "d"
            colors = dict(matplotlib.colors.BASE_COLORS, **matplotlib.colors.CSS4_COLORS)
            colors["r"] = (1.0, 0.0, 0.0)
            colors["g"] = (0.0, 1.0, 0.0)
            colors["b"] = (0.0, 0.0, 1.0)
            colors["c"] = (0.0, 1.0, 1.0)
            colors["m"] = (1.0, 0.0, 1.0)
            colors["y"] = (1.0, 1.0, 0.0)
            colors["k"] = (0.0, 0.0, 0.0)
            colors["w"] = (1.0, 1.0, 1.0)
            if x not in colors:
                return None
            r = colors[x]
            if not isinstance(r, tuple):
                # CSS4 entries are hex strings.
                r = tuple(matplotlib.colors.to_rgba(r)[:3])
            if mode == "l":
                return (r[0] * 0.5 + 0.5, r[1] * 0.5 + 0.5, r[2] * 0.5 + 0.5)
            elif mode == "d":
                return (r[0] * 0.5, r[1] * 0.5, r[2] * 0.5)
            else:
                return r
        return x

    @staticmethod
    def color_blend(cols, style):
        """Blend a list of RGB colors.

        ``style`` "hsv" averages in HSV space, averaging hue as a vector
        weighted by saturation*value and dropping saturation when the hues
        cancel out. "avghsv" averages the "hsv" and plain-average results.
        Anything else is a plain RGB average.
        """
        if len(cols) == 0:
            return (0.0, 0.0, 0.0)
        if len(cols) == 1:
            return tuple(cols[0])
        if style == "avghsv":
            r = [CellTable.color_blend(cols, "hsv"), CellTable.color_blend(cols, "avg")]
            return tuple([sum(c[i] for c in r) / len(r) for i in range(3)])
        if style == "hsv":
            hsv = [matplotlib.colors.rgb_to_hsv(c) for c in cols]
            s = sum(t[1] for t in hsv) / len(hsv)
            v = sum(t[2] for t in hsv) / len(hsv)
            hvecs = [(math.cos((t[0] - 0.5) * 2 * math.pi) * t[1] * t[2],
                      math.sin((t[0] - 0.5) * 2 * math.pi) * t[1] * t[2]) for t in hsv]
            hvec = [sum(t[i] for t in hvecs) / len(hvecs) for i in range(2)]
            h = 0.0
            if hvec[0]**2 + hvec[1]**2 > 0.03**2:
                h = math.atan2(hvec[1], hvec[0]) / (2 * math.pi) + 0.5
            else:
                s = 0.0
            return matplotlib.colors.hsv_to_rgb((h, s, v))
        return tuple([sum(c[i] for c in cols) / len(cols) for i in range(3)])

    def get_expr_color(self, i, color_shift):
        """Return the fill color of expression ``i`` (None = the cval shading).

        The expression's own "color" entry wins. Otherwise a built-in
        16-color palette is indexed by ``i + color_shift``, treating None
        as 0, and wraps around.
        """
        r = None
        if i is not None:
            r = CellTable.get_color(self.exprs[i].get("color", None))
        if r is None:
            clist = [(1, 0.5, 0.5), (0.5, 0.5, 1), (0.5, 1, 0.5), (1, 1, 0.2), (1, 0.3, 1), (0.3, 1, 1),
                     (1, 0.8, 0.3), (0.6, 0.6, 0.6)]
            clist += [tuple(max(t[i] - 0.3, 0.0) for i in range(3)) for t in clist]
            cid = (0 if i is None else i) + color_shift
            return clist[cid % len(clist)]
        return r

    def plot(self, style = "hsplit", legend = True, use_latex = True, figsize = None, color = None, cval_color = None, lib = "matplotlib"):
        """Draw the table and show the figure.

        Parameters:
            style : comma-separated style options, appended after
                    PsiOpts.settings["venn_style"] (later options win).
                    See ``_plot_figure`` for the recognised options.
            legend : whether to draw a legend of the added expressions.
            use_latex : True to typeset labels with LaTeX (``text.usetex``
                    is enabled only for the duration of this call), False
                    for plain strings, or "mathregular" to use mathtext's
                    ``\\mathregular`` without usetex.
            figsize : figure size; defaults to PsiOpts.settings["figsize"],
                    then [10, 8].
            color : list (or comma-separated string) of colors for the
                    expressions, in order; see ``get_color``.
            cval_color : color used to shade cells by their "cval".
            lib : plotting library name resolved by PsiOpts.get_plotlib;
                    "plotly" converts the figure with plotly.
        """
        lib = PsiOpts.get_plotlib(lib)
        use_mathregular = False
        if isinstance(use_latex, str) and use_latex == "mathregular":
            use_mathregular = True
            use_latex = True
        if figsize is None:
            figsize = PsiOpts.settings["figsize"]
        if figsize is None:
            figsize = [10, 8]
        style = PsiOpts.settings.get("venn_style", "") + "," + style
        if color is None:
            color = []
        elif isinstance(color, str):
            color = color.split(",")
        color_list = [CellTable.get_color(c) for c in color]
        # rc_context snapshots the global rcParams and restores them on exit,
        # even if drawing raises. Assigning to ``plt.rcParams`` would only
        # rebind pyplot's module attribute and would leave matplotlib's
        # settings (e.g. text.usetex) modified.
        with matplotlib.rc_context():
            if use_latex and not use_mathregular:
                plt.rcParams.update({"text.usetex": True,
                                     "font.family": "sans-serif",
                                     "font.sans-serif": ["Helvetica"]})
            self._plot_figure(style, legend, use_latex, use_mathregular,
                              figsize, color_list, cval_color, lib)

    def _plot_figure(self, style, legend, use_latex, use_mathregular, figsize, color_list, cval_color, lib):
        """Create, draw and show the figure for ``plot`` (rcParams already set).

        ``style`` is a comma-separated option list. Any unrecognised option
        selects the fill layout, e.g. "hsplit" (default), "hsplit_fixed" or
        "hatch". Recognised flags:
            nofill, pm / num, text / notext, sign / nosign,
            signhatch / nosignhatch, legend / nolegend,
            legend_left / legend_right / legend_top / legend_bottom,
            val_outline / val_nooutline, nocval, nocval_color,
            dp0 .. dp9 (decimal places of cval), venn,
            blend_hsv / blend_avg / blend_avghsv.
        """
        label_interval = 0.32
        label_width = 0.29
        fontsize = self.fontsize
        fontsize_in_mul = 1.0
        linewidth = self.linewidth
        fig, ax = plt.subplots(figsize = figsize)
        xlim = [0, 0]
        ylim = [0, 0]
        expr_hatch = [None] * len(self.exprs)
        mask_nexpr = [0] * (1 << len(self.x))
        hatches = ['//', '\\\\', '||', '--', 'o', 'O', '.', '*', '+', 'x']

        # Range of the "cval" attribute over all cells (0 is always included).
        cval_present = False
        cval_min = 0.0
        cval_max = 0.0
        for mask in range(1, 1 << len(self.x)):
            cval = self.get_attr(mask, "cval")
            if cval is not None:
                cval = float(cval)
                cval_present = True
                cval_min = min(cval_min, cval)
                cval_max = max(cval_max, cval)

        # Style defaults; None means "derive from the fill layout below".
        rect_draw = True
        pm = False
        text_draw = True
        val_outline = None
        neg_hatch = None
        is_blend = None
        blend_style = "avghsv"
        is_venn = False
        is_venn_overlap = False
        text_sub_draw = True
        neg_hatch_style = "//"
        hatch_all = True
        color_shift = 0
        numdp = 4
        legend_loc = "right"
        cval_color_enabled = True
        cval_ignore = False
        style_split = style.split(",")
        style = None
        for cstyle2 in style_split:
            cstyle = cstyle2.strip()
            if len(cstyle) == 0:
                continue
            if cstyle == "nofill":
                rect_draw = False
            elif cstyle == "pm":
                pm = True
            elif cstyle == "num":
                pm = False
            elif cstyle == "text":
                text_draw = True
            elif cstyle == "notext":
                text_draw = False
            elif cstyle == "sign":
                text_sub_draw = True
            elif cstyle == "nosign":
                text_sub_draw = False
            elif cstyle == "signhatch":
                neg_hatch = True
            elif cstyle == "nosignhatch":
                neg_hatch = False
            elif cstyle == "legend":
                legend = True
            elif cstyle == "nolegend":
                legend = False
            elif cstyle == "legend_left":
                legend = True
                legend_loc = "left"
            elif cstyle == "legend_right":
                legend = True
                legend_loc = "right"
            elif cstyle == "legend_top":
                legend = True
                legend_loc = "top"
            elif cstyle == "legend_bottom":
                legend = True
                legend_loc = "bottom"
            elif cstyle == "val_outline":
                val_outline = True
            elif cstyle == "val_nooutline":
                val_outline = False
            elif cstyle == "nocval":
                cval_ignore = True
            elif cstyle == "nocval_color":
                cval_color_enabled = False
            elif len(cstyle) == 3 and cstyle.startswith("dp") and cstyle[2] in "0123456789":
                numdp = int(cstyle[2])
            elif cstyle == "venn":
                is_venn = True
            elif cstyle == "blend_hsv":
                style = "blend"
                blend_style = "hsv"
            elif cstyle == "blend_avg":
                style = "blend"
                blend_style = "avg"
            elif cstyle == "blend_avghsv":
                style = "blend"
                blend_style = "avghsv"
            else:
                style = cstyle
        if style is None:
            style = "hsplit"
        if cval_ignore:
            cval_present = False
        if cval_min >= cval_max:
            # No spread of values: shading would divide by zero.
            cval_color_enabled = False
        if cval_present:
            if cval_color_enabled:
                if cval_color is not None:
                    cval_color = CellTable.get_color(cval_color)
                if cval_color is None:
                    # No color given, or an unrecognised name: take one from the palette.
                    cval_color = self.get_expr_color(None, color_shift)
                    color_shift += 1
        else:
            cval_color_enabled = False
        if cval_color_enabled:
            style = "hatch"
        if len(self.x) > 5:
            # Ellipse layouts exist only for up to 5 sets.
            is_venn = False
        if is_venn:
            if len(self.x) == 4:
                fontsize_in_mul = 0.9
            if len(self.x) == 5:
                fontsize_in_mul = 0.8
        if val_outline is None:
            val_outline = style == "hatch" or style == "text" or style == "blend"
        if neg_hatch is None:
            neg_hatch = style == "hsplit_fixed" or style == "hsplit" or style == "blend"
        if is_blend is None:
            is_blend = style == "blend"

        # Count expressions per cell and decide which expressions get a hatch.
        for mask in range(1, 1 << len(self.x)):
            expr_pres = [ei for ei in range(len(self.exprs))
                         if self.get_attr(mask, "val_" + str(ei)) is not None]
            mask_nexpr[mask] = len(expr_pres)
            if style == "hatch":
                if len(expr_pres) >= 2:
                    for ei in expr_pres:
                        expr_hatch[ei] = True
        if style == "hatch":
            if hatch_all or cval_color_enabled:
                for ei in range(len(self.exprs)):
                    expr_hatch[ei] = True
        nover = 0
        for ei in range(len(self.exprs)):
            if expr_hatch[ei]:
                expr_hatch[ei] = hatches[nover % len(hatches)]
                nover += 1

        n = len(self.x)
        axlen = [0, 0]
        axslen = [0, 0]
        xstr = []
        for xit in range(n):
            if use_latex:
                if use_mathregular:
                    xstr.append(r"$\mathregular{" + self.x[xit].latex() + r"}$")
                else:
                    xstr.append("$" + self.x[xit].latex() + "$")
            else:
                xstr.append(str(self.x[xit]))

        polys = [None] * (1 << n)
        poly_els = [None] * n
        textps = [None] * n
        if is_venn:
            polys, textps = ivenn.calc_ellipses(n)
            poly_els = [ivenn.intersect(n, [i == j for j in range(n)]) for i in range(n)]
            for xit in range(n):
                ax.text(textps[xit][0], textps[xit][1], xstr[xit],
                        horizontalalignment="center", verticalalignment="center", fontsize = fontsize)
        if not is_venn:
            # Grid layout: grey label bars along the rows and columns.
            for xit in range(n):
                cax, clist = self.get_x_poss(xit)
                axslen[cax] += 1
            for xit in range(n):
                cax, clist = self.get_x_poss(xit)
                avgpos = [0.0, 0.0]
                caxlen = axlen[cax]
                if cax:
                    caxlen = axslen[cax] - 1 - caxlen
                for v in clist:
                    pos = (v, -(caxlen + 1) * label_interval)
                    size = (1, label_width)
                    if cax:
                        pos = (pos[1], pos[0])
                        size = (size[1], size[0])
                    pos = (pos[0], -pos[1] - size[1])
                    avgpos[0] += pos[0] + size[0] * 0.5
                    avgpos[1] += pos[1] + size[1] * 0.5
                    xlim[0] = min(xlim[0], pos[0])
                    xlim[1] = max(xlim[1], pos[0] + size[0])
                    ylim[0] = min(ylim[0], pos[1])
                    ylim[1] = max(ylim[1], pos[1] + size[1])
                    ax.add_patch(matplotlib.patches.Rectangle(
                        pos, size[0], size[1], facecolor = "lightgray"))
                avgpos[0] /= len(clist)
                avgpos[1] /= len(clist)
                ax.text(avgpos[0], avgpos[1], xstr[xit],
                        horizontalalignment="center", verticalalignment="center", fontsize = fontsize)
                axlen[cax] += 1

        # Pass 0 fills cells, pass 2 draws outlines, pass 4 writes values.
        passes = [0, 2, 4]
        for cpass in passes:
            for mask in range(1, 1 << len(self.x)):
                if cpass == 2:
                    if is_venn and iutil.bitcount(mask) != 1:
                        continue
                pos = None
                size = None
                if is_venn:
                    pos = polys[mask][1]
                    size = 1.2 / n
                    if n == 4:
                        size *= 0.95
                    elif n == 5:
                        nbit = iutil.bitcount(mask)
                        if nbit >= 2 and nbit <= 4:
                            size *= 0.55
                        elif nbit == 5:
                            size *= 1.5
                        else:
                            size *= 1.2
                    size = (size, size)
                    pos = (pos[0] - size[0] * 0.5, pos[1] - size[1] * 0.5)
                else:
                    pos = self.get_pos(mask)
                    pos = (pos[0], -pos[1] - 1)
                    size = (1, 1)
                xlim[0] = min(xlim[0], pos[0])
                xlim[1] = max(xlim[1], pos[0] + size[0])
                ylim[0] = min(ylim[0], pos[1])
                ylim[1] = max(ylim[1], pos[1] + size[1])
                isenabled = self.get_enabled(mask)
                cval = None
                if cval_present:
                    cval = self.get_attr(mask, "cval")
                    if cval is not None:
                        cval = float(cval)
                color = "none"
                if not isenabled:
                    color = "k"
                elif (cval is not None) and cval_color_enabled:
                    # Linear shade from white (cval_min) to cval_color (cval_max).
                    cval_scaled = (cval - cval_min) / (cval_max - cval_min)
                    color = tuple(cval_color[i] * cval_scaled + 1.0 - cval_scaled for i in range(3))
                if cpass == 0 and color != "none":
                    params = {
                        "facecolor": color
                    }
                    if is_venn:
                        ax.add_patch(matplotlib.patches.Polygon(
                            polys[mask][0], closed=True, **params))
                    else:
                        ax.add_patch(matplotlib.patches.Rectangle(
                            pos, size[0], size[1], **params))
                if cpass == 0:
                    params = {
                        "facecolor": "white",
                        "edgecolor": "none"
                    }
                    if is_venn and is_venn_overlap:
                        ax.add_patch(matplotlib.patches.Polygon(
                            polys[mask][0], closed=True, **params))
                if cpass == 2:
                    params = {
                        "facecolor": "none",
                        "linestyle": "-",
                        "linewidth": linewidth,
                        "edgecolor": "k"
                    }
                    if is_venn:
                        # Single-bit masks only: outline that ellipse.
                        for i in range(n):
                            if mask & (1 << i):
                                ax.add_patch(matplotlib.patches.Polygon(
                                    poly_els[i], closed=True, **params))
                                break
                    else:
                        ax.add_patch(matplotlib.patches.Rectangle(
                            pos, size[0], size[1], **params))
                    continue
                if isenabled:
                    cnexpr = 0
                    snexpr = mask_nexpr[mask]
                    collist = []
                    collistsol = []
                    nsol = 0
                    if cval is not None:
                        if cpass == 4:
                            if text_draw:
                                text_y = 0.7
                                if len(self.exprs) == 0:
                                    text_y = 0.5
                                ctext = ("{:." + str(numdp) + "f}").format(cval)
                                ax.text(pos[0] + 0.5 * size[0], pos[1] + size[1] * text_y, ctext,
                                        horizontalalignment="center", verticalalignment="center", fontsize = fontsize * fontsize_in_mul,
                                        color = "k")
                    for ei in range(len(self.exprs)):
                        v = self.get_attr(mask, "val_" + str(ei))
                        if v is None:
                            continue
                        if pm:
                            ctext = iutil.float_tostr(abs(v), bracket = False)
                            if ctext == "1":
                                ctext = ""
                            if v >= 0:
                                ctext = "+" + ctext
                            else:
                                ctext = "-" + ctext
                        else:
                            ctext = iutil.float_tostr(v, bracket = False)
                        ccolor = None
                        if ei < len(color_list):
                            ccolor = color_list[ei]
                        else:
                            ccolor = self.get_expr_color(ei, color_shift)
                        chatch = expr_hatch[ei]
                        hatch_invert = False
                        if chatch is not None:
                            hatch_invert = True
                        if neg_hatch and v < 0:
                            chatch = neg_hatch_style
                        rect_x = 0.0
                        rect_w = 1.0
                        text_x = (cnexpr + 0.5) * 1.0 / snexpr
                        text_y = 0.5
                        if cval is not None:
                            text_y = 0.3
                        if style == "hsplit_fixed":
                            rect_x = ei * 1.0 / len(self.exprs)
                            rect_w = 1.0 / len(self.exprs)
                            text_x = (ei + 0.5) * 1.0 / len(self.exprs)
                        elif style == "hsplit":
                            rect_x = cnexpr * 1.0 / snexpr
                            rect_w = 1.0 / snexpr
                        if isinstance(ccolor, tuple):
                            collist.append(ccolor)
                            if not (neg_hatch and v < 0):
                                collistsol.append(ccolor)
                                nsol += 1
                        if cpass == 0:
                            # In blend mode only the last expression of the cell paints it.
                            if rect_draw and (not is_blend or cnexpr == snexpr - 1):
                                hatch_col = (1.0, 1.0, 1.0)
                                if is_blend and cnexpr == snexpr - 1:
                                    ccolor = CellTable.color_blend(collist, blend_style)
                                    if nsol > 0:
                                        hatch_col = CellTable.color_blend(collistsol, blend_style)
                                    if abs(ccolor[0] - hatch_col[0]) + abs(ccolor[1] - hatch_col[1]) + abs(ccolor[2] - hatch_col[2]) <= 0.001:
                                        chatch = None
                                    else:
                                        chatch = neg_hatch_style
                                params = {
                                    "facecolor": "none" if hatch_invert else ccolor,
                                    "hatch": chatch,
                                    "edgecolor": ccolor if hatch_invert else hatch_col if chatch else "none",
                                    "linewidth": 0
                                }
                                if is_venn:
                                    ax.add_patch(matplotlib.patches.Polygon(
                                        polys[mask][0], closed=True, **params))
                                else:
                                    ax.add_patch(matplotlib.patches.Rectangle(
                                        (pos[0] + rect_x * size[0], pos[1]),
                                        size[0] * rect_w, size[1], **params))
                        elif cpass == 4:
                            if text_draw:
                                ax.text(pos[0] + text_x * size[0], pos[1] + size[1] * text_y, ctext,
                                        horizontalalignment="center", verticalalignment="center", fontsize = fontsize * fontsize_in_mul,
                                        color = ccolor if val_outline else "k",
                                        path_effects = [matplotlib.patheffects.withStroke(linewidth = 3.5, foreground = "k")] if val_outline else None)
                        cnexpr += 1
                    if cpass == 0:
                        remtext = ""
                        if self.get_attr(mask, "ispos"):
                            remtext += "+"
                        if self.get_attr(mask, "isneg"):
                            remtext += "-"
                        if text_sub_draw and len(remtext):
                            rempos = None
                            if is_venn:
                                rempos = polys[mask][2]
                                rempos = (rempos[0], rempos[1] + size[1] * 0.033)
                            else:
                                rempos = (pos[0] + size[0] * 0.015, pos[1] + size[1] * 0.02)
                            ax.text(rempos[0], rempos[1], remtext,
                                    horizontalalignment="center" if is_venn else "left", verticalalignment="bottom",
                                    fontsize = fontsize * fontsize_in_mul * 0.85)
        if is_venn:
            xlim = [-1.1, 1.1]
            ylim = [-1.1, 1.1]
        if legend and len(self.exprs):
            legends = []
            for ei in range(len(self.exprs)):
                ccolor = None
                if ei < len(color_list):
                    ccolor = color_list[ei]
                else:
                    ccolor = self.get_expr_color(ei, color_shift)
                clabel = ""
                cexpr = self.exprs[ei].get("expr")
                if cexpr is not None:
                    with PsiOpts(str_lhsreal = False):
                        if use_latex:
                            if use_mathregular:
                                clabel = r"$\mathregular{" + cexpr.latex(skip_simplify = True) + r"}$"
                            else:
                                clabel = "$" + cexpr.latex(skip_simplify = True) + "$"
                        else:
                            clabel = str(cexpr)
                cval = self.exprs[ei].get("cval")
                if cval_present and (cval is not None):
                    if not isinstance(cval, bool):
                        cval = float(cval)
                    if isinstance(cval, float):
                        clabel += " = " + ("{:." + str(numdp) + "f}").format(cval)
                    else:
                        clabel += " = " + str(cval)
                if not len(clabel):
                    continue
                legends.append(matplotlib.patches.Patch(facecolor=ccolor, label=clabel))
            bbox_buffer = 0.03
            bbox_to_anchor = (0.5, -bbox_buffer)
            loc = "upper center"
            if legend_loc == "top":
                bbox_to_anchor = (0.5, 1.0 + bbox_buffer)
                loc = "lower center"
            elif legend_loc == "right":
                bbox_to_anchor = (1.0 + bbox_buffer, 0.5)
                loc = "center left"
            elif legend_loc == "left":
                bbox_to_anchor = (-bbox_buffer, 0.5)
                loc = "center right"
            ax.legend(handles = legends, fontsize = fontsize, bbox_to_anchor=bbox_to_anchor,
                      loc=loc, frameon=False)
        if is_venn:
            ax.set_aspect("equal")
        plt.axis("off")
        plt.xlim([xlim[0] - 0.011, xlim[1] + 0.011])
        plt.ylim([ylim[0] - 0.011, ylim[1] + 0.011])
        if lib == "plotly":
            plotly.offline.plot_mpl(fig)
        else:
            plt.show()
            # NOTE(review): tight_layout runs after show(); on most backends the
            # figure is already rendered by then - confirm before moving it.
            fig.tight_layout()
class ProofObj(IBaseObj):
    """A tree of proof steps.

    Each node holds:
      claim  : None, ("equiv", a, b), ("implies", a, b), a list of
               [expr, eqnstr, note] triples forming an (in)equality chain,
               or any object with a ``tostring(style=...)`` method;
      desc   : optional heading printed before the claim;
      steps  : child ProofObj nodes;
      parent : the node this one was attached to (set by ``+=``,
               ``insert_step`` and ``step_in``);
      meta   : free-form dict (keys read here: "case", "note_color",
               "note_newline").
    """
    def __init__(self, claim, desc = None, steps = None, parent = None, meta = None):
        self.claim = claim
        self.desc = desc
        if steps is None:
            steps = []
        self.steps = steps
        self.parent = parent
        if meta is None:
            meta = {}
        self.meta = meta

    @staticmethod
    def empty():
        """Return a node with no claim, no description and no steps."""
        return ProofObj(None, None, [])

    def isempty(self):
        """Return True if this node has no claim, no description and no steps."""
        return self.claim is None and self.desc is None and len(self.steps) == 0

    def copy(self):
        """Return a copy of this node (claim, desc, steps, meta copied via iutil.copy; parent shared)."""
        return ProofObj(iutil.copy(self.claim), iutil.copy(self.desc),
                        iutil.copy(self.steps), self.parent, iutil.copy(self.meta))

    def copy_(self, other):
        """Overwrite this node in place with a copy of ``other``."""
        self.claim = iutil.copy(other.claim)
        self.desc = iutil.copy(other.desc)
        self.steps = iutil.copy(other.steps)
        self.parent = other.parent
        self.meta = iutil.copy(other.meta)

    @staticmethod
    def from_region(x, c = ""):
        """Return a leaf node whose claim is a copy of ``x`` and whose description is ``c``."""
        return ProofObj(iutil.copy(x), c)

    @staticmethod
    def from_chain(x):
        """Build a leaf node from a flat chain ``[e0, note?, rel1, e1, note?, rel2, e2, ...]``.

        Relations (e.g. "<=", "=") are strings; an optional non-string item
        after an expression is its note. The result's claim is a list of
        [expr, relation-to-previous, note] triples (the first relation is
        ""), and its description is "Steps: ".
        """
        ineqchain = []
        i = 0
        while i < len(x):
            ceqnstr = ""
            if i > 0:
                # Consume the relation string linking to the previous expression.
                ceqnstr = x[i]
                i += 1
                if i >= len(x):
                    break
            csum = x[i]
            i += 1
            cclaim = []
            if i < len(x) and not isinstance(x[i], str):
                cclaim = x[i]
                i += 1
            ineqchain.append([csum, ceqnstr, cclaim])
        return ProofObj.from_region(ineqchain, c = "Steps: ")

    def get_case(self):
        """Return meta["case"] if set, otherwise the universe region."""
        if "case" in self.meta:
            return self.meta["case"]
        return Region.universe()

    def set_case(self, reg):
        """Store ``reg`` as meta["case"]."""
        self.meta["case"] = reg

    def __iadd__(self, other):
        """Append a copy of ``other`` as a child step (its parent becomes self)."""
        other = other.copy()
        other.parent = self
        self.steps.append(other)
        return self

    def insert_step(self, other, pos):
        """Insert a copy of ``other`` as a child step at index ``pos``; returns self."""
        other = other.copy()
        other.parent = self
        self.steps.insert(pos, other)
        return self

    def __add__(self, other):
        """Return a copy of self with a copy of ``other`` appended as a child step."""
        r = self.copy()
        r += other
        return r

    def step_in(self, other):
        """Append a copy of ``other`` as a child step and return that new child."""
        other = other.copy()
        other.parent = self
        self.steps.append(other)
        return other

    def step_out(self):
        """Return the parent node."""
        return self.parent

    def clear(self):
        """Remove all child steps."""
        self.steps = []

    def step_regions(self, flat = True):
        """Collect the claims of this subtree as regions.

        A node with ``desc is None`` contributes only its steps. Otherwise
        "equiv" tuples become ``equiv(a, b)``, "implies" tuples become
        ``a >> b``, a chain becomes one relation per consecutive pair (via
        iutil.op_str, with ``isle`` True if any relation is "<="), and any
        other claim is used as-is. Child results are concatenated when
        ``flat``, otherwise appended as nested lists.
        """
        r = []
        if self.desc is None:
            pass
        else:
            if self.claim is not None:
                if isinstance(self.claim, tuple):
                    if self.claim[0] == "equiv":
                        r.append(equiv(self.claim[1], self.claim[2]))
                    elif self.claim[0] == "implies":
                        r.append(self.claim[1] >> self.claim[2])
                else:
                    if isinstance(self.claim, list):
                        isle = any(eqnstr == "<=" for expr, eqnstr, note in self.claim)
                        prev = None
                        for i, a in enumerate(self.claim):
                            expr, eqnstr, note = a
                            if prev is not None:
                                r.append(iutil.op_str(eqnstr, prev, expr, isle = isle))
                            prev = expr
                    else:
                        r.append(self.claim)
        for i, x in enumerate(self.steps):
            t = x.step_regions(flat=flat)
            if flat:
                r += t
            else:
                r.append(t)
        return r

    def tostring(self, style = 0, prefix = "", inden = 0, skip_env = False, outermost = True):
        """Render the proof tree as text (or LaTeX when ``style`` includes STR_STYLE_LATEX).

        Parameters:
            style : string style (converted with iutil.convert_str_style).
            prefix : step number of this node, e.g. "1.2"; children get
                     "<prefix>.<k>".
            inden : indentation depth, in space units.
            skip_env : do not wrap the output in an align* environment.
            outermost : while True, a description is printed only if the
                     node has steps; it becomes False once a heading is
                     printed or the node has two or more steps.
        """
        style = iutil.convert_str_style(style)
        slstr = ""
        nlstr = "\n"
        if style & PsiOpts.STR_STYLE_LATEX:
            # align* rows: "&" anchors each line, "\\" ends it.
            slstr = "&"
            nlstr = "\\\\\n"
        spacestr = " "
        if style & PsiOpts.STR_STYLE_LATEX:
            spacestr = "\\;"
        slstr = slstr + spacestr * inden
        inden2 = inden
        if self.desc is not None:
            # Children of a titled node are indented further.
            inden2 += 2
        r = ""
        cinden = 4
        cprefix = prefix
        if style & PsiOpts.STR_STYLE_LATEX:
            if not skip_env:
                r += "\\begin{align*}\n"
        if self.desc is None:
            cinden = 0
        else:
            if not outermost or self.steps:
                outermost = False
                r += slstr
                if prefix:
                    r += prefix + "." + spacestr
                    cprefix += "."
                r += iutil.tostring_join(self.desc, style, nlstr = nlstr + slstr + spacestr * 8)
                if self.claim is not None:
                    r += nlstr
        if self.claim is not None:
            if isinstance(self.claim, tuple):
                if self.claim[0] == "equiv":
                    r += slstr + self.claim[1].tostring(style = style)
                    if style & PsiOpts.STR_STYLE_PSITIP:
                        r += nlstr + slstr + "==" + nlstr
                    elif style & PsiOpts.STR_STYLE_LATEX:
                        r += nlstr + slstr + PsiOpts.settings["latex_equiv"] + nlstr
                    else:
                        r += nlstr + slstr + "<=>" + nlstr
                    r += slstr + self.claim[2].tostring(style = style)
                elif self.claim[0] == "implies":
                    r += slstr + self.claim[1].tostring(style = style)
                    if style & PsiOpts.STR_STYLE_PSITIP:
                        r += nlstr + slstr + ">>" + nlstr
                    elif style & PsiOpts.STR_STYLE_LATEX:
                        r += nlstr + slstr + PsiOpts.settings["latex_implies"] + nlstr
                    else:
                        r += nlstr + slstr + "=>" + nlstr
                    r += slstr + self.claim[2].tostring(style = style)
            else:
                if isinstance(self.claim, list):
                    # (In)equality chain: one line per [expr, eqnstr, note] triple.
                    note_color = None
                    note_newline = False
                    if "note_color" in self.meta:
                        note_color = self.meta["note_color"]
                    if "note_newline" in self.meta:
                        note_newline = self.meta["note_newline"]
                    if style & PsiOpts.STR_STYLE_LATEX and PsiOpts.settings["latex_line_len"] is not None:
                        # An int note_newline is a line-length threshold; cap it by latex_line_len.
                        if note_newline is False:
                            note_newline = 100000000
                        if note_newline is not True:
                            note_newline = min(note_newline, PsiOpts.settings["latex_line_len"])
                    for i, a in enumerate(self.claim):
                        expr, eqnstr, note = a
                        if i:
                            r += nlstr
                        cline = expr.tostring_line_len(eqnstr = eqnstr, style = style)
                        r += slstr + cline
                        if note is not None and not (isinstance(note, list) and len(note) == 0):
                            cnote = iutil.tostring_join(note, style)
                            cnumspace = 3
                            # Put the note on its own line if forced or if the line would be too long.
                            if note_newline is True or (note_newline is not False and
                                isinstance(note_newline, int) and iutil.latex_len(cline + cnote) > note_newline):
                                r += nlstr + slstr
                                cnumspace = 8
                            if style & PsiOpts.STR_STYLE_LATEX:
                                r += "\\;" * cnumspace
                                if note_color is not None:
                                    r += "{\\color{" + str(note_color) + "}{"
                            else:
                                r += " " * cnumspace
                            r += cnote
                            if style & PsiOpts.STR_STYLE_LATEX:
                                if note_color is not None:
                                    r += "}}"
                else:
                    r += slstr + self.claim.tostring(style = style)
            r += nlstr
        if self.steps:
            r += nlstr
        if len(self.steps) >= 2:
            outermost = False
        for i, x in enumerate(self.steps):
            if i:
                r += nlstr
            r += x.tostring(style, cprefix + str(i + 1), inden = inden2, skip_env = True, outermost = outermost)
        if style & PsiOpts.STR_STYLE_LATEX:
            if not skip_env:
                r += "\\end{align*}\n"
        return r

    def __str__(self):
        """Render with the default string style from PsiOpts.settings["str_style"]."""
        return self.tostring(PsiOpts.settings["str_style"])

    @latex_postprocess
    def _latex_(self):
        """Render as LaTeX (an align* environment)."""
        return self.tostring(iutil.convert_str_style("latex"))
class CodingNode(IBaseObj):
    """A node of a coding model.

    Stores the node's output random variables ``rv_out``, auxiliary-variable
    bookkeeping (``aux_out``, ``aux_dec``, ``aux_ndec``, ``aux_ndec_try``,
    ``aux_ndec_force``, ``aux_borrow``), its causal inputs
    ``rv_in_causal`` / ``rv_in_scausal`` (presumably "strictly causal" -
    confirm), ``rv_ndec_force``, ``ndec_mode`` and a display ``label``.
    """
    def __init__(self, rv_out, aux_out = None, aux_dec = None, aux_ndec = None, rv_in_causal = None, rv_in_scausal = None, ndec_mode = None, label = None, rv_ndec_force = None):
        """Create a node.

        ``aux_out`` may be a Comp, None, a single pair ``(v0, v1)`` or a list
        of pairs. With pairs, the list is kept in ``aux_out_sublist`` and
        ``aux_out`` becomes the sum of the first elements. The input
        arguments ``rv_in_causal``, ``rv_in_scausal`` and ``rv_ndec_force``
        are copied; None means empty.
        """
        self.rv_out = rv_out
        if isinstance(aux_out, tuple):
            aux_out = [aux_out]
        if isinstance(aux_out, list):
            self.aux_out_sublist = aux_out
            self.aux_out = Comp.empty()
            for v0, v1 in self.aux_out_sublist:
                self.aux_out += v0
        else:
            self.aux_out_sublist = []
            self.aux_out = aux_out
        self.aux_dec = aux_dec
        self.aux_ndec = aux_ndec
        if rv_in_causal is None:
            self.rv_in_causal = Comp.empty()
        else:
            self.rv_in_causal = rv_in_causal.copy()
        if rv_in_scausal is None:
            self.rv_in_scausal = Comp.empty()
        else:
            self.rv_in_scausal = rv_in_scausal.copy()
        self.aux_ndec_try = Comp.empty()
        self.aux_ndec_force = Comp.empty()
        self.aux_borrow = Comp.empty()
        if rv_ndec_force is None:
            self.rv_ndec_force = Comp.empty()
        else:
            self.rv_ndec_force = rv_ndec_force.copy()
        self.ndec_mode = ndec_mode
        self.label = label

    def copy(self):
        """Return a copy of this node with every field copied."""
        r = CodingNode(self.rv_out.copy())
        r.aux_out = iutil.copy(self.aux_out)
        r.aux_out_sublist = [(a.copy(), b.copy()) for a, b in self.aux_out_sublist]
        r.aux_dec = iutil.copy(self.aux_dec)
        r.aux_ndec = iutil.copy(self.aux_ndec)
        r.rv_in_causal = iutil.copy(self.rv_in_causal)
        # Must be copied explicitly: the constructor call above leaves it empty.
        r.rv_in_scausal = iutil.copy(self.rv_in_scausal)
        r.aux_ndec_try = iutil.copy(self.aux_ndec_try)
        r.aux_ndec_force = iutil.copy(self.aux_ndec_force)
        r.aux_borrow = iutil.copy(self.aux_borrow)
        r.rv_ndec_force = iutil.copy(self.rv_ndec_force)
        r.ndec_mode = self.ndec_mode
        r.label = self.label
        return r

    def clear(self):
        """Reset the auxiliary-variable assignments (rv_* fields are kept)."""
        self.aux_out = None
        self.aux_out_sublist = []
        self.aux_dec = None
        self.aux_ndec = None
        self.aux_ndec_try = Comp.empty()
        self.aux_ndec_force = Comp.empty()
        self.aux_borrow = Comp.empty()

    def record_to(self, index, skip_msg = False):
        """Call ``record_to(index)`` on rv_out and on aux_out, aux_dec, aux_ndec when set.

        ``skip_msg`` is accepted for interface compatibility but unused here.
        """
        self.rv_out.record_to(index)
        if self.aux_out is not None:
            self.aux_out.record_to(index)
        if self.aux_dec is not None:
            self.aux_dec.record_to(index)
        if self.aux_ndec is not None:
            self.aux_ndec.record_to(index)
class CodingModel(IBaseObj):
def __init__(self, bnet = None, sublist = None, nodes = None, reg = None, partition = None):
if bnet is None:
self.bnet = BayesNet()
else:
if isinstance(bnet, BayesNet):
self.bnet = bnet
else:
self.bnet = BayesNet(bnet)
if reg is None:
self.reg = Region.universe()
else:
self.reg = reg
if sublist is None:
self.sublist = []
else:
self.sublist = sublist
self.sublist_rate = []
self.sublist_const = []
self.nodes = []
if nodes is None:
pass
else:
for node in nodes:
if isinstance(node, CodingNode):
self.nodes.append(node)
else:
self.nodes.append(CodingNode(node))
self.inner_mode = "plain"
self.bnet_out = None
self.bnet_in = None
self.bnet_out_arrays = None
self.partition = partition
self.use_union = True
self.aux_importance = []
self.aux_dummy = Comp.empty()
self.aux_nondummy = Comp.empty()
self.created_rv = Comp.empty()
def copy(self):
r = CodingModel()
r.bnet = iutil.copy(self.bnet)
r.reg = iutil.copy(self.reg)
r.sublist = [(a.copy(), b.copy()) for a, b in self.sublist]
r.sublist_rate = [(a.copy(), b.copy()) for a, b in self.sublist_rate]
r.sublist_const = [(a.copy(), b.copy()) for a, b in self.sublist_const]
r.nodes = [a.copy() for a in self.nodes]
r.inner_mode = self.inner_mode
r.bnet_out = iutil.copy(self.bnet_out)
r.bnet_in = iutil.copy(self.bnet_in)
r.partition = iutil.copy(self.partition)
r.use_union = self.use_union
r.aux_importance = iutil.copy(self.aux_importance)
r.aux_dummy = iutil.copy(self.aux_dummy)
r.aux_nondummy = iutil.copy(self.aux_nondummy)
r.created_rv = iutil.copy(self.created_rv)
return r
def __iadd__(self, other):
if isinstance(other, CodingNode):
self.nodes.append(other)
else:
self.bnet += other
return self
def __iand__(self, other):
if isinstance(other, Region):
self.reg = self.reg & other
return self
def add_edge(self, a, b, coded = False, children_edge = None, is_fcn = False,
rv_in_causal = None, rv_in_scausal = None, **kwargs):
if rv_in_causal is None:
rv_in_causal = Comp.empty()
if rv_in_scausal is None:
rv_in_scausal = Comp.empty()
ex = Comp.empty()
b2 = Comp.empty()
for c in b:
if self.bnet.index.get_index(c) >= 0:
ex += c
else:
b2 += c
ac = a.copy() + rv_in_causal - rv_in_scausal
acb = Comp.empty()
if not ex.isempty():
# cname = str(ex) + "?@@100@@#EX_"
# for c in a:
# cname += str(c)
# cname += "_"
# for c in ex:
# cname += str(c)
cname = str(ex) + "?"
cname += "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@"
cname += ex.tostring(style = PsiOpts.STR_STYLE_LATEX)
cname += "?[" + a.tostring(style = PsiOpts.STR_STYLE_LATEX) + "]"
exc = Comp.rv(cname)
self.created_rv += exc
self.bnet.add_edge(ac, exc)
if is_fcn:
self.bnet.set_fcn(exc)
# if children_edge:
# ac += ex
for c in ex:
if children_edge is True or (children_edge is None and not self.is_rate(c)):
acb += c
if coded:
self.nodes.append(CodingNode(exc, rv_in_causal = rv_in_causal,
rv_in_scausal = rv_in_scausal, **kwargs))
self.sublist.append((exc, ex))
for c in b2:
if children_edge is True or (children_edge is None and not self.is_rate(c)):
self.bnet.add_edge(ac + acb, c)
acb += c
else:
self.bnet.add_edge(ac, c)
if is_fcn:
self.bnet.set_fcn(c)
if coded:
self.nodes.append(CodingNode(c, rv_in_causal = rv_in_causal,
rv_in_scausal = rv_in_scausal, **kwargs))
def add_node(self, a, b, **kwargs):
self.add_edge(a, b, coded = True, **kwargs)
def add_source(self, a):
for i in range(1, len(a)):
self.add_edge(a[:i], a[i])
def set_rate(self, a, rate, dummy = False):
if isinstance(rate, (int, float)):
const_rate = rate
rate = Expr.real("#TMPRATE" + str(a))
self.sublist_const.append((rate, Expr.const(const_rate)))
self.sublist_rate.append((a, rate))
if not dummy:
self.aux_nondummy += a
def set_dummy(self, a):
self.set_rate(a, 0, dummy = True)
def get_aux(self):
r = Comp.empty()
for node in self.nodes:
r += node.aux_out
return r
def get_parents_noaux(self, x, bnet):
if not self.get_aux().ispresent(x):
return x.copy()
t = bnet.get_parents(x)
return sum((self.get_parents_noaux(y, bnet) for y in t), Comp.empty())
def get_node_rv_out_from_aux_dec(self, x):
r = Comp.empty()
for node in self.nodes:
if (node.aux_dec is not None and node.aux_dec.ispresent(x)) or (node.aux_ndec is not None and node.aux_ndec.ispresent(x)):
r += node.rv_out
return r
def is_aux_dummy(self, x):
return self.aux_dummy.ispresent(x)
def get_rv_rates(self, v):
v = v.copy()
r = Expr.zero()
for v0, v1 in self.sublist + self.sublist_rate:
if v0.ispresent(v):
if isinstance(v1, Expr):
r += v1
else:
v.substitute(v0, v1)
return r
def is_rate(self, v):
return not self.get_rv_rates(v).iszero()
def is_rate_zero(self, v):
if not self.is_rate(v):
return False
v = self.get_rv_rates(v)
for v0, v1 in self.sublist_const:
v.substitute(v0, v1)
v.simplify_quick()
return v.iszero()
def get_rate_rvs(self):
r = Comp.empty()
for a in self.bnet.index.comprv:
if self.is_rate(a):
r += a
return r
def is_src_rate(self, v):
if not self.is_rate(v):
return False
rate = self.get_rv_rates(v)
if not rate.isrealvar():
return False
if self.bnet.get_parents(v).isempty():
return False
count = 0
for v0, v1 in self.sublist + self.sublist_rate:
if v1.ispresent(v):
return False
if v1.ispresent(rate):
count += 1
if count >= 2:
return False
return True
def get_refs(self, v):
r = v.copy()
for a in self.bnet.index.comprv:
if self.get_rv_ratervs(a).ispresent(v):
r += a
return r
def is_ch_rate(self, v):
if not self.is_rate(v):
return False
rate = self.get_rv_rates(v)
if not rate.isrealvar():
return False
if not self.bnet.get_parents(v).isempty():
return False
g = self.get_refs(v) - v
for a in g:
if not self.bnet.get_children(a).isempty():
return False
return True
# def is_src_rate(self, v):
# if not self.is_rate(v):
# return False
# rate = self.get_rv_rates(v)
# if not rate.isrealvar():
# return False
# if self.find_node_rv_out(v) is None:
# return False
# return True
def get_rv_sub(self, v):
v = v.copy()
for v0, v1 in self.sublist + self.sublist_rate:
if v0.ispresent(v):
if isinstance(v1, Expr):
v.substitute(v0, v0[0])
# v = v0[0].copy()
else:
v.substitute(v0, v1)
return v
def get_rv_ratervs(self, v):
v = v.copy()
r = Comp.empty()
for v0, v1 in self.sublist + self.sublist_rate:
if v0.ispresent(v):
if isinstance(v1, Expr):
r += v0[0]
else:
v.substitute(v0, v1)
return r
def is_rv_original(self, v):
for v0, v1 in self.sublist + self.sublist_rate:
if v0.ispresent(v) and not isinstance(v1, Expr):
return False
return True
def get_node_immediate_rates(self, node):
x = node.rv_out + self.bnet.get_parents(node.rv_out)
return self.get_rv_rates(x)
def get_node_immediate_ratervs(self, node, original = False, include_ndec_force = False):
x = node.rv_out + self.bnet.get_parents(node.rv_out)
if original:
x = sum((y for y in x if self.is_rv_original(y)), Comp.empty())
if include_ndec_force:
x += node.rv_ndec_force
return self.get_rv_ratervs(x)
def is_balanced(self):
rvout = Comp.empty()
rvin = Comp.empty()
for node in self.nodes:
for a in node.rv_out:
if self.is_rate(a):
continue
rvout += rv(self.get_rv_sub(a))
for a in self.bnet.get_parents(node.rv_out):
if self.is_rate(a):
continue
rvin += rv(self.get_rv_sub(a))
return rvout.inter(rvin).isempty()
def get_node_descendants(self, node):
x = Comp.empty()
nodes = []
if isinstance(node, CodingNode):
nodes = [node]
else:
nodes = node
for node in nodes:
x += self.bnet.get_descendants(node.rv_out)
for node in nodes:
x -= node.rv_out
r = []
for node2 in self.nodes:
if x.ispresent(node2.rv_out):
r.append(node2)
return r
def get_node_dependent_after(self, node):
found = False
r = []
for node2 in self.nodes:
if found:
if not self.bnet.check_ic(Expr.I(node.rv_out, node2.rv_out)):
r.append(node2)
if node2 is node:
found = True
return r
def get_nodes_rv_out(self):
r = Comp.empty()
for node in self.nodes:
r += node.rv_out
return r
def find_node_rv_out(self, x):
for node in self.nodes:
if x.ispresent(node.rv_out):
return node
return None
def find_nodes_immediate_ratervs(self, a):
r = []
for node in self.nodes:
if not self.get_node_descendants(node):
continue
ratervs = self.get_node_immediate_ratervs(node)
if ratervs.ispresent(a):
r.append(node)
return r
def is_src(self, x):
if self.is_rate(x) or self.find_node_rv_out(x) is not None:
return False
for y in self.bnet.get_parents(x):
if not self.is_src(y):
return False
return True
def get_srcs(self):
r = Comp.empty()
for a in self.bnet.index.comprv:
if self.is_src(a):
r += a
return r
def chan_in_get_out(self, x):
if any(self.is_rate(a) for a in x):
return Comp.empty()
anc = sum((self.bnet.get_ancestors(y) for y in x), Comp.empty()) - x
r = x.copy()
did = True
while did:
did = False
for a in self.bnet.index.comprv - r - anc:
if self.is_rate(a) or self.find_node_rv_out(a) is not None:
continue
if not r.super_of(self.bnet.get_parents(a)):
continue
r += a
did = True
r -= x
return r
def get_chans(self):
r = []
for x in igen.subset(self.bnet.index.comprv, minsize=1):
y = self.chan_in_get_out(x)
if not y.isempty():
r.append((x, y))
return r
def rv_single_node(self, a):
nodes = self.find_nodes_immediate_ratervs(a)
node2 = self.find_node_rv_out(a)
c = len(nodes)
if node2 is not None and node2 not in nodes:
c += 1
return c <= 1
def remove_aux(self, x):
for node in self.nodes:
node.aux_out -= x
node.aux_dec -= x
node.aux_ndec -= x
node.aux_ndec_try -= x
node.aux_ndec_force -= x
node.aux_borrow -= x
def get_node_union_parents(self, nodes):
parents = None
for node in nodes:
t = self.bnet.get_parents(node.rv_out) - node.rv_in_causal
if parents is None:
parents = t
else:
parents = parents.inter(t)
return parents
def get_aux_union_parents(self, aux):
nodes = [node for node in self.nodes if node.aux_out.ispresent(aux) or node.aux_borrow.ispresent(aux)]
return self.get_node_union_parents(nodes)
def get_node_unions(self):
n = len(self.nodes)
vis = set()
r = []
for mask in range((1 << n) - 1, 0, -1):
nodes = [node for i, node in enumerate(self.nodes) if mask & (1 << i)]
parents = self.get_node_union_parents(nodes)
# print(mask)
# print(parents)
# print()
if parents is None or parents.isempty():
continue
parents_mask = self.bnet.index.get_mask(parents)
if parents_mask in vis:
continue
vis.add(parents_mask)
r.append(nodes)
return r
def get_index(self):
index = IVarIndex()
self.bnet.index.comprv.record_to(index)
for node in self.nodes:
if node.aux_out is not None:
node.aux_out.record_to(index)
for v0, v1 in self.sublist + self.sublist_rate:
v0.record_to(index)
v1.record_to(index)
return index
def name_avoid(self, name):
return self.get_index().name_avoid(name)
def calc_node_union(self, nodes):
inner_mode = self.inner_mode
for node in nodes:
if node.aux_out is None:
node.aux_out = Comp.empty()
ratervs = sum((self.get_node_immediate_ratervs(node, original = True) for node in nodes), Comp.empty())
rv_out = sum((node.rv_out for node in nodes), Comp.empty())
des = self.get_node_descendants(nodes)
ratervs = sum((c for c in ratervs if self.is_ch_rate(c) or any(self.get_node_immediate_ratervs(d, include_ndec_force = True).ispresent(c) for d in des)), Comp.empty())
# des = self.get_node_dependent_after(node)
crvs = list(ratervs)
if len(crvs) == 0:
crvs.append(Comp.empty())
inner_mode = "combinations"
if inner_mode == "combinations":
for c in crvs:
for mask in range(1, 1 << len(des)):
cname = "A_" + rv_out.tostring(style = PsiOpts.STR_STYLE_STANDARD)
# cname_l = ("A_{" + rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX)
# + ("" if c.isempty() else "," + c.tostring(style = PsiOpts.STR_STYLE_LATEX)) + "}")
cname_l = "A_{" + rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX)
if len(des) > 1:
cname_l += " \\to "
ccount = 0
for i, d in enumerate(des):
if mask & (1 << i):
cname += "_" + d.rv_out.tostring(style = PsiOpts.STR_STYLE_STANDARD)
if ccount:
cname_l += ","
ccount += 1
cname_l += d.rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX)
cname_l += "}"
cname += ("" if c.isempty() else "_" + c.tostring(style = PsiOpts.STR_STYLE_STANDARD))
cname_l += ("" if c.isempty() else "^{" + c.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}")
cname += "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + cname_l
cname = self.name_avoid(cname)
a = Comp.rv(cname)
# self.aux_importance.append((a, -iutil.bit))
for i, node in enumerate(nodes):
if i == 0:
node.aux_out += a
node.aux_out_sublist.append((a, a + c))
else:
node.aux_borrow += a
for i, d in enumerate(des):
if mask & (1 << i):
if d.aux_dec is None:
d.aux_dec = Comp.empty()
d.aux_dec += a
else:
d.aux_ndec_try += a
if d.rv_ndec_force.ispresent(c):
d.aux_ndec_force += a
elif inner_mode == "plain":
if len(des):
for c in crvs:
single_node = self.rv_single_node(c)
if single_node:
cname = ("A_" + c.tostring(style = PsiOpts.STR_STYLE_STANDARD))
else:
cname = ("A_" + rv_out.tostring(style = PsiOpts.STR_STYLE_STANDARD)
+ ("" if c.isempty() or len(crvs) == 1 else "_" + c.tostring(style = PsiOpts.STR_STYLE_STANDARD)))
cname += "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@"
# print(str(c) + " " + str(len(self.find_nodes_immediate_ratervs(c))))
if single_node:
cname += ("A_{" + c.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}")
else:
cname += ("A_{" + rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}"
+ ("" if c.isempty() or len(crvs) == 1 else "^{" + c.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}"))
cname = self.name_avoid(cname)
a = Comp.rv(cname)
if self.is_rate_zero(c) and not self.aux_nondummy.ispresent(c):
self.aux_dummy += a
for i, node in enumerate(nodes):
if i == 0:
node.aux_out += a
node.aux_out_sublist.append((a, a + c))
else:
# pass
node.aux_borrow += a
for d in des:
if self.get_node_immediate_ratervs(d).ispresent(c):
if d.aux_dec is None:
d.aux_dec = Comp.empty()
d.aux_dec += a
else:
d.aux_ndec_try += a
if d.rv_ndec_force.ispresent(c):
d.aux_ndec_force += a
def calc_node(self, node):
inner_mode = self.inner_mode
if node.aux_out is None:
node.aux_out = Comp.empty()
ratervs = self.get_node_immediate_ratervs(node)
des = self.get_node_descendants(node)
# des = self.get_node_dependent_after(node)
crvs = list(ratervs)
if len(crvs) == 0:
crvs.append(Comp.empty())
inner_mode = "combinations"
if inner_mode == "combinations":
for c in crvs:
for mask in range(1, 1 << len(des)):
cname = "A_" + node.rv_out.tostring(style = PsiOpts.STR_STYLE_STANDARD)
# cname_l = ("A_{" + node.rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX)
# + ("" if c.isempty() else "," + c.tostring(style = PsiOpts.STR_STYLE_LATEX)) + "}")
cname_l = "A_{" + node.rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX)
for i, d in enumerate(des):
if mask & (1 << i):
cname += "_" + d.rv_out.tostring(style = PsiOpts.STR_STYLE_STANDARD)
cname_l += "," + d.rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX)
cname_l += "}"
cname += ("" if c.isempty() else "_" + c.tostring(style = PsiOpts.STR_STYLE_STANDARD))
cname_l += ("" if c.isempty() else "^{" + c.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}")
cname += "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + cname_l
cname = self.name_avoid(cname)
a = Comp.rv(cname)
# self.aux_importance.append((a, -iutil.bit))
node.aux_out += a
node.aux_out_sublist.append((a, a + c))
for i, d in enumerate(des):
if mask & (1 << i):
if d.aux_dec is None:
d.aux_dec = Comp.empty()
d.aux_dec += a
else:
d.aux_ndec_try += a
if d.rv_ndec_force.ispresent(c):
d.aux_ndec_force += a
elif inner_mode == "plain":
if len(des):
for c in crvs:
single_node = self.rv_single_node(c)
if single_node:
cname = ("A_" + c.tostring(style = PsiOpts.STR_STYLE_STANDARD))
else:
cname = ("A_" + node.rv_out.tostring(style = PsiOpts.STR_STYLE_STANDARD)
+ ("" if c.isempty() or len(crvs) == 1 else "_" + c.tostring(style = PsiOpts.STR_STYLE_STANDARD)))
cname += "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@"
# print(str(c) + " " + str(len(self.find_nodes_immediate_ratervs(c))))
if single_node:
cname += ("A_{" + c.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}")
else:
cname += ("A_{" + node.rv_out.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}"
+ ("" if c.isempty() or len(crvs) == 1 else "^{" + c.tostring(style = PsiOpts.STR_STYLE_LATEX) + "}"))
cname = self.name_avoid(cname)
a = Comp.rv(cname)
node.aux_out += a
node.aux_out_sublist.append((a, a + c))
for d in des:
if self.get_node_immediate_ratervs(d).ispresent(c):
if d.aux_dec is None:
d.aux_dec = Comp.empty()
d.aux_dec += a
else:
d.aux_ndec_try += a
if d.rv_ndec_force.ispresent(c):
d.aux_ndec_force += a
def calc_node_unions(self):
if all(node.aux_out is not None for node in self.nodes):
return
unions = self.get_node_unions()
for node in self.nodes:
if node.aux_out is None:
node.aux_out = Comp.empty()
for nodes in unions:
self.calc_node_union(nodes)
def calc_nodes(self):
for node in self.nodes:
self.calc_node(node)
def clear_cache(self):
for node in self.nodes:
node.clear()
def presolve(self):
self.sublist += self.sublist_rate
self.sublist_rate = []
def get_region(self, oneshot = False):
r = self.bnet.copy()
tosublist = []
for node in self.nodes:
tosublist += node.aux_out_sublist
tosublist += self.sublist + self.sublist_rate
convexify = None
if not oneshot:
convexify = Comp.rv("#TMPVAR_GET_REGION")
for node in self.nodes:
r += (convexify, node.rv_out)
r = r.contracted_node(convexify)
for v0, v1 in tosublist:
if isinstance(v1, Expr) and not isinstance(v0, Expr):
r = r.contracted_node(v0)
else:
r = r.contracted_node(v0)
if self.partition is not None:
reg = Region.universe()
for i, a in enumerate(self.partition):
a2 = sum((b for j, b in enumerate(self.partition) if j != i), Comp.empty())
reg &= r.eliminated(a2).get_region()
return reg
return r.get_region()
def is_netcode(self):
"""Whether this setting is a network coding setting, i.e., all random variables are messages with rates.
"""
for a in self.bnet.index.comprv:
if not self.is_rate(a):
return False
return True
def get_netcode_region(self, convexify = None, mi = True, skip_simplify = False):
"""Rate region of network coding setting.
If mi = False, use the Yan-Yeung-Zhang region (CAUTION: Convexification and closure is not performed):
X. Yan, R. W. Yeung, and Z. Zhang, "An implicit characterization of the achievable rate region for
acyclic multisource multisink network coding," IEEE Trans. Inf. Theory, vol. 58, no. 9, pp. 5625-5639, 2012.
"""
self.presolve()
if self.bnet.tsorted() is None:
raise ValueError("Non-strictly-causal cycle detected.")
return None
if convexify is None:
convexify = True
if convexify is not False:
if convexify is True:
convexify = Comp.index("Q_i")
if not isinstance(convexify, Comp):
convexify = Comp.empty()
r = Region.universe()
msgs = Comp.empty()
for a in self.bnet.index.comprv:
pa = self.bnet.get_parents(a)
if pa.isempty():
if not mi:
r &= Expr.I(msgs, a) == 0
msgs += a
rate = self.get_rv_rates(a)
if not mi:
r &= Expr.H(a) >= rate
r &= rate >= 0
else:
a2 = self.get_rv_ratervs(a)
if a == a2:
rate = self.get_rv_rates(a)
if mi:
r &= Expr.I(a, pa) <= rate
else:
r &= Expr.Hc(a, pa) == 0
r &= Expr.H(a) <= rate
else:
for a3 in a2:
rate = self.get_rv_rates(a3)
if mi:
r &= Expr.I(a3, pa) >= rate
else:
r &= Expr.Hc(a3, pa) == 0
if mi:
tbnet = self.bnet.copy()
for a in tbnet.index.comprv:
tbnet.set_fcn(a, False)
a2 = self.get_rv_ratervs(a)
if a != a2:
tbnet = tbnet.contracted_node(a)
r &= tbnet.get_region()
if False:
tosublist = []
for node in self.nodes:
tosublist += node.aux_out_sublist
tosublist += self.sublist
for v0, v1 in tosublist:
if not (isinstance(v1, Expr) and not isinstance(v0, Expr)):
r.substitute(v0, v1)
for tozero, toconst in self.sublist_const:
r.substitute(tozero, toconst)
r = r.exists(self.bnet.index.comprv.inter(r.allcomp()))
if not skip_simplify:
r.simplify_aux_commonpart(maxlen = 3, forbid_h = False)
r = r.simplified()
return r
def src_combine(self):
r = []
groups = []
for a in self.bnet.index.comprv:
if self.is_src_rate(a):
p = self.bnet.get_parents(a)
c = self.bnet.get_children(a)
for g in groups:
p2 = p.copy()
for a2, c2 in g[1]:
p2 -= a2
if g[0] == p2:
g[1].append((a, c))
break
else:
groups.append((p, [(a, c)]))
for p, g in groups:
mask_set = set()
for a, c in g:
mask_set.add(self.bnet.index.get_mask(c))
mask_list = list(mask_set)
for mmask in range(1, 1 << len(mask_list)):
cmask = None
for i, mask in enumerate(mask_list):
if not mmask & (1 << i):
continue
if cmask is None:
cmask = mask
else:
cmask |= mask
if cmask == 0 or cmask in mask_set:
continue
mask_set.add(cmask)
cc = self.bnet.index.from_mask(cmask)
cg = [a for a, c in g if cc.super_of(c)]
# cname = "(" + str(sum(cg)) +")" + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + "(" + sum(cg).tostring(style = "latex") + ")"
cname = str(sum(cg)) + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + sum(cg).tostring(style = "latex")
cname = self.name_avoid(cname)
crv = Comp.rv(cname)
self.set_dummy(crv)
self.add_node(p, crv)
for cc2 in cc:
self.bnet.add_edge(crv, cc2)
self.bnet = self.bnet.tsorted()
return self
def ch_combine(self, min_ndummy = 0):
groups = []
for a in self.bnet.index.comprv:
if self.is_ch_rate(a):
p = self.bnet.get_children(a)
for g in groups:
if g[0] == p:
g[1].append(a)
break
else:
groups.append((p, [a]))
for p, g in groups:
decs = []
for i, a in enumerate(g):
for b in self.get_refs(a) - a:
bp = self.bnet.get_parents(b)
if bp not in decs:
decs.append(bp)
# Whether y is degraded compared to x
deg = [[self.bnet.check_ic(Expr.Ic(p, x, y)) for y in decs] for x in decs]
idecs = [0 for i in range(len(g))]
for i, a in enumerate(g):
for b in self.get_refs(a) - a:
bp = self.bnet.get_parents(b)
bpi = decs.index(bp)
for j in range(len(decs)):
if deg[bpi][j]:
idecs[i] |= 1 << j
# print(idecs)
dec_vis = list(idecs)
ndummy = 0
for mask in range(1, 1 << len(g)):
cdec = 0
cg = Comp.empty()
for i, a in enumerate(g):
if mask & (1 << i):
cdec |= idecs[i]
cg += a
# print(str(mask) + " " + str(cdec))
if cdec in dec_vis:
continue
# cname = "(" + str(sum(g)) + ")" + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + "(" + sum(g).tostring(style = "latex") + ")"
cname = str(sum(cg, Comp.empty())) + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + sum(cg, Comp.empty()).tostring(style = "latex")
cname = self.name_avoid(cname)
crv = Comp.rv(cname)
# crate = Expr.real("MRATE_" + str(p) + "_" + str(mask))
# self.set_rate(crv, crate)
self.set_dummy(crv)
for pb in p:
self.bnet.add_edge(crv, pb)
# self.add_edge(crv, p, children_edge = False)
# for j in range(len(decs)):
# if cdec & (1 << j):
# self.add_node(decs[j], crv)
ndummy += 1
for i, a in enumerate(g):
if idecs[i] | cdec == cdec:
# print("TRY " + str(a) + " " + str(crv))
for b in self.get_refs(a) - a:
node = self.find_node_rv_out(b)
if node is not None:
node.rv_ndec_force += crv
# print("ADD " + str(node.rv_out) + " " + str(crv))
# self.sublist_zero.append(crate)
dec_vis.append(cdec)
for i in range(min_ndummy - ndummy):
cname = str(sum(g, Comp.empty())) + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + sum(g, Comp.empty()).tostring(style = "latex")
cname = self.name_avoid(cname)
crv = Comp.rv(cname)
self.set_dummy(crv)
for pb in p:
self.bnet.add_edge(crv, pb)
self.bnet = self.bnet.tsorted()
return self
def get_src_splice(self):
r = []
groups = []
for a in self.bnet.index.comprv:
if self.is_src_rate(a):
p = self.bnet.get_parents(a)
c = self.bnet.get_children(a)
for g in groups:
p2 = p.copy()
for a2, c2 in g[1]:
p2 -= a2
if g[0] == p2:
g[1].append((a, c))
break
else:
groups.append((p, [(a, c)]))
def covered(g, i, mask):
ci = g[i][1]
cj = Comp.empty()
for j in range(len(g)):
if mask & (1 << j):
cj += g[j][1]
return cj.super_of(ci)
for p, g in groups:
for i in range(len(g)):
ci = g[i][1]
for mask in igen.subset_mask((1 << len(g)) - 1 - (1 << i)):
if not covered(g, i, mask):
continue
bad = False
for j in range(len(g)):
if mask & (1 << j) and covered(g, i, mask - (1 << j)):
bad = True
break
if bad:
continue
r.append((self.get_rv_rates(g[i][0]),
[self.get_rv_rates(g[j][0]) for j in range(len(g)) if mask & (1 << j)]))
return r
def get_inner(self, convexify = None, ndec_mode = "all", skip_simplify = False, skip_aux = False, splice = True, combine = True, is_proof = False, num_dummy = 0, target = None):
"""Get an inner bound of the capacity region via [Lee-Chung 2015] together with simplification procedures.
Si-Hyeon Lee, and Sae-Young Chung. "A unified approach for network information theory."
2015 IEEE International Symposium on Information Theory (ISIT). IEEE, 2015.
"""
if target is not None:
skip_aux = False
eliminate_quick = False
if self.is_netcode():
return self.get_netcode_region(convexify = convexify, skip_simplify = skip_simplify)
if combine:
cs = self.copy()
cs.ch_combine(num_dummy)
cs.src_combine()
return cs.get_inner(convexify, ndec_mode, skip_simplify, skip_aux, splice, False, is_proof, 0, target)
verbose = PsiOpts.settings.get("verbose_codingmodel", False)
if verbose:
print("============== Get Inner ==============")
self.presolve()
if self.bnet.tsorted() is None:
raise ValueError("Non-strictly-causal cycle detected.")
return None
if convexify is None:
convexify = self.convexify_needed()
if convexify is not False:
if convexify is True:
convexify = Comp.index("Q_i")
if self.use_union:
self.calc_node_unions()
else:
self.calc_nodes()
for node in self.nodes:
if node.aux_ndec is None:
r = RegionOp.union([])
cndec_mode = node.ndec_mode
if cndec_mode is None:
cndec_mode = ndec_mode
if verbose:
print("============== NDec ==============")
print("Mode : " + cndec_mode)
print("NDec try : " + str(node.aux_ndec_try))
print("NDec force: " + str(node.aux_ndec_force))
if cndec_mode == "all":
# print("NODE " + str(node.rv_out) + " " + str(node.aux_ndec_force))
for cndec0 in igen.subset(node.aux_ndec_try - node.aux_ndec_force):
cndec = node.aux_ndec_force + cndec0
# print(str(node.rv_out) + " " + str(cndec))
node.aux_ndec = cndec if isinstance(cndec, Comp) else Comp.empty()
t = self.get_inner(convexify, ndec_mode = ndec_mode, skip_simplify = True, skip_aux = True, splice = splice, combine = combine, is_proof = is_proof, num_dummy = num_dummy, target = target)
if t is not None:
if target is not None:
r = t
break
r |= t
elif cndec_mode == "min" or cndec_mode == "none":
# node.aux_ndec = Comp.empty()
node.aux_ndec = node.aux_ndec_force.copy()
r = self.get_inner(convexify, ndec_mode = ndec_mode, skip_simplify = True, skip_aux = True, splice = splice, combine = combine, is_proof = is_proof, num_dummy = num_dummy, target = target)
elif cndec_mode == "max":
node.aux_ndec = node.aux_ndec_try.copy()
r = self.get_inner(convexify, ndec_mode = ndec_mode, skip_simplify = True, skip_aux = True, splice = splice, combine = combine, is_proof = is_proof, num_dummy = num_dummy, target = target)
node.aux_ndec = None
if r is None:
return None
if target is not None:
return r
if not skip_aux:
aux = self.get_aux()
caux = aux.copy()
if convexify is not False:
caux += convexify
r = r.eliminate(caux.inter(r.allcomprv()), quick = eliminate_quick)
if skip_simplify:
return r
else:
r.remove_missing_aux()
r.simplify()
r.simplify_union()
# print(r)
r.remove_missing_aux()
r = r.simplified_quick()
if verbose:
print("============== NDec Simplified ==============")
print(r)
return r
aux = self.get_aux()
# aux_rates = [Expr.real("#TR" + str(i)) for i in range(len(aux))]
cindex = self.get_index()
aux_rates = []
for a in aux:
cname = iutil.name_prefix_suffix(a.get_name(), "R_{", "}")
cname = cindex.name_avoid(cname)
aux_rates.append(Expr.real(cname))
aux_rates[-1].record_to(cindex)
if verbose:
print("============== Rates ==============")
for a, ar in zip(aux, aux_rates):
print(str(a) + ": " + str(ar))
cbnet = self.bnet.copy()
r = Region.universe()
for rate in aux_rates:
r &= rate >= 0
nodes_code = [[] for node in self.nodes]
for node, node_code in zip(self.nodes, nodes_code):
y = self.bnet.get_parents(node.rv_out) - node.rv_in_causal
aux_dec = Comp.empty()
if node.aux_dec is not None:
aux_dec = node.aux_dec
auxlist = [(auxx, self.get_aux_union_parents(auxx)) for auxx in node.aux_out]
# auxlist.sort(key = lambda auxx, auxp: )
for i, (auxx, auxp) in enumerate(auxlist):
caux = Comp.empty()
for (aux2, aux2p) in auxlist[:i]:
if auxp.super_of(aux2p):
caux += aux2
cbnet.add_edge(aux_dec + auxp + node.aux_borrow + caux, auxx)
cbnet.add_edge(aux_dec + y + node.aux_borrow + node.aux_out, node.rv_out)
# Decoding constraints
dec_inp = y + node.aux_borrow
for cdec in igen.subset(aux_dec, minsize = 1):
for cndec in igen.subset(node.aux_ndec):
cd = cdec + cndec
expr = Expr.zero()
cc = Comp.empty()
for c in cd:
ci = aux.index_of(c)
expr += aux_rates[ci]
expr -= I(c & cc + (aux_dec + node.aux_ndec - cd) + dec_inp)
cc += c
pf_note = ["dec ", dec_inp.copy(), " to ", "#CMD_EXCLUDE_PREV", cdec.copy()]
if not cndec.isempty():
pf_note += [" ", "nonuniq ", cndec.copy()]
r &= (expr <= 0).add_meta("pf_note", pf_note)
if not aux_dec.isempty():
node_code.append({
"mode": "decode",
"inp": dec_inp.copy(),
"oup": aux_dec.copy(),
"ndec": node.aux_ndec.copy()
})
# Encoding constraints
for i, (auxx, auxp) in enumerate(auxlist):
caux = Comp.empty()
caux_out = Comp.empty()
bad = False
for j, (aux2, aux2p) in enumerate(auxlist):
if auxp == aux2p:
if j > i:
bad = True
break
caux_out += aux2
elif j < i and auxp.super_of(aux2p):
caux += aux2
if bad:
continue
enc_inp = aux_dec + auxp + node.aux_borrow + caux
for ce in igen.subset(caux_out, minsize = 1):
expr = Expr.zero()
cc = Comp.empty()
for c in ce:
ci = aux.index_of(c)
expr += aux_rates[ci]
expr -= I(c & cc + enc_inp)
cc += c
# if self.is_rate(auxp):
# r &= (expr >= 0)
# else:
pf_note = ["enc ", enc_inp.copy(), " to ", "#CMD_EXCLUDE_PREV", ce.copy()]
r &= (expr >= 0).add_meta("pf_note", pf_note).add_meta("pf_note_priority", -1)
if not caux_out.isempty():
node_code.append({
"mode": "encode",
"inp": enc_inp.copy(),
"oup": caux_out.copy()
})
if not node.rv_out.isempty():
node_code.append({
"mode": "generate",
"inp": sum((cbnet.get_parents(x) for x in node.rv_out), Comp.empty()),
"oup": node.rv_out.copy()
})
# Check dummy duplicates
if verbose:
print("============== Check Dummy ==============")
for a in aux:
if not self.is_aux_dummy(a):
continue
a_parents = self.get_parents_noaux(a, cbnet)
a_rvout = self.get_node_rv_out_from_aux_dec(a)
if verbose:
print(str(a) + ": " + str(a_parents) + " -> " + str(a_rvout))
if a_rvout.isempty():
return None
for b in aux - a:
# if not self.is_aux_dummy(b):
# continue
if a_parents == self.get_parents_noaux(b, cbnet):
b_rvout = self.get_node_rv_out_from_aux_dec(b)
if a_rvout == b_rvout:
return None
if not self.is_aux_dummy(b) and b_rvout.super_of(a_rvout):
return None
if verbose:
print("============== Check Dummy End ==============")
if convexify is not False:
r.condition(convexify)
for node in self.nodes:
cbnet.add_edge(convexify, node.aux_out + node.rv_out)
self.bnet_in = cbnet
r &= cbnet.get_region()
allrv = r.allcomprv()
if is_proof:
r.add_meta("inner_rv_list", CompArray(list(allrv)), children = False)
tosublist = []
for node in self.nodes:
tosublist += node.aux_out_sublist
tosublist += self.sublist
if verbose:
print("============== Nodes ==============")
for node in self.nodes:
print(str(self.bnet.get_parents(node.rv_out))
+ (", (borrow " + str(node.aux_borrow) + ")" if not node.aux_borrow.isempty() else "")
+ " -> " + str(node.rv_out) + ", aux "
+ str(node.aux_out) + ", dec " + str(node.aux_dec)
+ (", (ndec " + str(node.aux_ndec) + ")" if not node.aux_ndec.isempty() else ""))
print("============== Bayes net ==============")
print(cbnet)
print("============== Aux ==============")
for a in aux:
print(str(a) + (" (dummy)" if self.is_aux_dummy(a) else "") + ": " + str(self.get_parents_noaux(a, cbnet)) + " -> " + str(self.get_node_rv_out_from_aux_dec(a)))
print("============== Inner bound ==============")
# print(r)
r.print(note=True)
with PsiOpts(meta_subs_criteria = "eqtype"):
for v0, v1 in tosublist:
r.substitute(v0, v1)
if isinstance(v1, Expr) and not isinstance(v0, Expr):
r &= v1 >= 0
# if iutil.check_meta_subs_criteria(v0, v1, self):
# for node, node_code in zip(self.nodes, nodes_code):
# iutil.substitute(node_code, v0, v1)
if verbose:
print("============== Substitute ==============")
for v0, v1 in tosublist:
print(str(v0) + " -> " + str(v1))
print("============== After substitute ==============")
# print(r)
r.print(note=True)
reg_record = None
if is_proof:
reg_record = []
r = r.eliminate(sum(aux_rates, Expr.zero()), quick = eliminate_quick, reg_record = reg_record)
# if convexify is not False:
# r.condition(convexify)
# rallcomprv = r.allcomprv()
# encout = self.get_nodes_rv_out().inter(rallcomprv)
# r &= markov(rallcomprv - encout - convexify, encout, convexify)
caux = aux.copy()
if convexify is not False:
caux += convexify
# r_prior = self.bnet.get_region().copy()
# assume_reg = Region.universe()
if not self.reg.isuniverse():
# if eliminate_quick and not skip_simplify:
# r = r.simplified()
if eliminate_quick:
r = r.simplified()
# r = r & self.reg
# r = r.and_cause_consequence(self.reg, avoid = caux)
# r = r.and_cause_consequence(self.reg, added_reg = r_prior)
r = r.and_cause_consequence(self.reg)
if self.partition is not None:
r &= indep(*(self.partition)).add_meta("pf_note", ["partition"])
if not skip_aux:
r = r.eliminate(caux.inter(r.allcomprv()), quick = eliminate_quick)
if verbose:
print("============== Elim aux rate ==============")
# print(r)
r.print(note=True)
if splice:
splice_list = self.get_src_splice()
for w0, w1 in splice_list:
r.splice_rate(w0, w1)
if verbose:
print("============== After splice ==============")
# print(r)
r.print(note=True)
for tozero, toconst in self.sublist_const:
r.substitute(tozero, toconst)
with PsiOpts(meta_subs_criteria = "eqtype"):
for a in cbnet.index.comprv:
if self.is_rate_zero(a):
r.substitute(a, Comp.empty())
did_simplify = False
if target is not None:
aux_assign = (target >> r).check_getaux()
if aux_assign is None:
return None
with PsiOpts(meta_subs_criteria = "eqtype"):
Comp.substitute_list(r, aux_assign, isaux = True)
r = r.eliminate(target.aux.inter(r.allcomprv()), quick = eliminate_quick)
r = r.simplified()
did_simplify = True
# print(r)
if skip_simplify:
pass
# return r
else:
# r = r.simplified(reg = r_prior)
r = r.simplified()
# r = r.simplified(self.reg)
r.remove_missing_aux()
did_simplify = True
if verbose:
print("============== Simplified ==============")
# print(r)
r.print(note=True)
if is_proof:
if skip_aux or not did_simplify:
if skip_aux:
r = r.eliminate(caux.inter(r.allcomprv()), quick = eliminate_quick)
# with PsiOpts(meta_subs_criteria = "eqtype"):
# r = r.simplified()
r = r.simplified()
# r = r.simplified(self.reg)
r.remove_missing_aux()
if skip_aux:
for c in caux:
r.substitute_aux(c, c)
did_simplify = True
if verbose:
print("============== Simplified (proof) ==============")
# print(r)
r.print(note=True)
allrv_after = r.get_meta("inner_rv_list")
for a in self.bnet.index.comprv + aux + allrv_after.allcomprv():
if self.is_rate_zero(a):
iutil.substitute(allrv_after, a, Comp.empty())
for node, node_code in zip(self.nodes, nodes_code):
iutil.substitute(node_code, a, Comp.empty())
if verbose:
print("============== rv after ==============")
print(allrv_after)
rnote = []
# print(r)
# print(repr(aux_after))
# print()
rate_reg_mode = "one"
def postprocess(x):
for a, b in zip(allrv, allrv_after):
iutil.substitute(x, a, b.copy())
if isinstance(x, Region):
for v0, v1 in tosublist:
x.substitute(v0, v1)
for tozero, toconst in self.sublist_const:
x.substitute(tozero, toconst)
rates_wanted = Expr.zero()
for a, b in zip(allrv, allrv_after):
if aux.ispresent(a) and not b.isempty():
crate = aux_rates[aux.index_of(a)]
rates_wanted += crate
# print(rates_wanted)
# print(reg_record)
# print()
Region.eliminate_reg_record_clean(reg_record, rates_wanted, cross=True)
# print(reg_record)
# print()
for node, node_code in zip(self.nodes, nodes_code):
for ccode in node_code:
if "oup" in ccode:
ccode["oup_cb"] = ccode["oup"].copy()
ccode["oup_id"] = ccode["oup"].copy()
if "ndec" in ccode:
ccode["ndec_cb"] = ccode["ndec"].copy()
ccode["ndec_id"] = ccode["ndec"].copy()
cnote = ["Codebook:"]
rnote.append(cnote)
cnaux = 0
for a, b in zip(allrv, allrv_after):
b_cb = b.copy()
b_id = b.copy()
if aux.ispresent(a):
id_name = ""
if len(str(cnaux + 1)) >= 2:
id_name = "i_{" + str(cnaux + 1) + "}"
else:
id_name = "i_" + str(cnaux + 1)
b_cb = sum((Comp.rv(iutil.name_prefix_suffix(x.get_name(), "", "[" + id_name + "]")) for x in b), Comp.empty())
b_id = Comp.rv(id_name)
if not b.isempty():
crate = aux_rates[aux.index_of(a)]
found = False
if rate_reg_mode == "detail":
cnote = [" " * 2, b_cb, ",", " ", "rate = ", crate]
rnote.append(cnote)
for reg_var, reg_r in reg_record:
if reg_var.get_name() == crate.get_name():
reg_r2 = reg_r.copy()
reg_r2.remove_meta("pf_note")
postprocess(reg_r2)
reg_r2.simplify(reg = r)
if reg_r2.isuniverse():
continue
if rate_reg_mode == "one":
one_bound = reg_r2.get_one_bound(crate, 2)
if one_bound is None:
continue
cnote = [" " * 2, b_cb, ",", " ", "rate = ", one_bound]
rnote.append(cnote)
found = True
elif rate_reg_mode == "detail":
cnote = [" " * 4, reg_r2]
rnote.append(cnote)
if rate_reg_mode == "one" and not found:
cnote = [" " * 2, b_cb]
rnote.append(cnote)
cnaux += 1
# print(a)
# print(b)
# print(b_cb)
# print(b_id)
# print()
for node, node_code in zip(self.nodes, nodes_code):
for ccode in node_code:
for key, value in ccode.items():
if key in ["oup_cb", "ndec_cb"]:
iutil.substitute(value, a, b_cb.copy())
elif key in ["oup_id", "ndec_id"]:
iutil.substitute(value, a, b_id.copy())
else:
iutil.substitute(value, a, b.copy())
for node, node_code in zip(self.nodes, nodes_code):
nline = 0
for ccode in node_code:
cnote = []
if ccode["mode"] in ("encode", "decode"):
if ccode["oup"].isempty():
continue
if nline == 0:
if node.label is not None and node.label != "":
cnote += [node.label, " "]
else:
cnote += [" " * 4]
# cnote += ["decodes", " ", ccode["inp"], " ", "to", " ", ccode["oup"]]
cnote += ["finds", " ", ccode["oup_id"], ":", " "]
if "ndec" in ccode and not ccode["ndec"].isempty():
cnote += ["exists", " ", ccode["ndec_id"], ":", " "]
cnote += ["(", ccode["inp"]]
if "ndec" in ccode and not ccode["ndec"].isempty():
cnote += [",", ccode["ndec_cb"]]
cnote += [",", ccode["oup_cb"], ")", "is typical"]
rnote.append(cnote)
nline += 1
elif ccode["mode"] == "generate":
oup = ccode["oup"] - ccode["inp"]
inp = ccode["inp"] - self.get_rate_rvs()
if oup.isempty():
continue
if nline == 0:
if node.label is not None and node.label != "":
cnote += [node.label, " "]
else:
cnote += [" " * 4]
cnote += ["generates", " ", oup, " ", "from", " ", inp]
rnote.append(cnote)
nline += 1
# print(repr(rnote))
r.add_meta("pf_note_pre", rnote, children = False)
return r
def nfold(self, n = 2, natural_eqprob = True, all_eqprob = True, ret_concmodel = None, first_noindex = True):
    """
    Return the n-letter setting.

    Every non-rate random variable X is replaced by an array of n copies
    (``rv_array(X.get_name(), n)``). Rate variables are kept as-is. The
    Bayesian network, the constraint region, the substitution list and the
    nodes are all rebuilt over these copies.

    Parameters
    ----------
    n : int
        The blocklength.
    natural_eqprob : bool
        If True (and all_eqprob is False), variables that are never the
        child of an edge in ``self.bnet`` are constrained to have the same
        entropy vector in every letter.
    all_eqprob : bool
        If True, every non-rate variable is constrained to have the same
        entropy vector in every letter.
    ret_concmodel : list or None
        If not None, a concrete example model of the n-letter network is
        appended to this list. It includes a uniform index ``Q_o`` over the
        n letters.
    first_noindex : bool
        If True, letter 0 reuses the original variable instead of an
        indexed copy.

    Returns
    -------
    CodingModel.
    """
    r = self.copy()
    bnet_out = BayesNet()
    reg_out_map = {}
    clen = n
    toeqprob = Comp.empty()
    for x in self.bnet.allcomp() + self.reg.allcomprv():
        t = None
        if self.is_rate(x):
            t = x.copy()
        else:
            t = rv_array(x.get_name(), n)
            if first_noindex:
                t[0] = x.copy()
            if natural_eqprob or all_eqprob:
                toeqprob += x
        reg_out_map[x.get_name()] = t
    for (a, b) in self.bnet.edges():
        if not all_eqprob:
            # Only "natural" variables (never a child) keep the
            # equal-distribution constraint.
            toeqprob -= b
        am = reg_out_map[a.get_name()]
        bm = reg_out_map[b.get_name()]
        node = self.find_node_rv_out(b)
        if node is None:
            bnet_out += (am, bm)
        else:
            if node.rv_in_causal.ispresent(a) and isinstance(am, CompArray) and isinstance(bm, CompArray):
                # NOTE(review): only t2 > t1 is connected here, so the
                # same-letter edge am[t] -> bm[t] is not added. get_outer
                # uses range(t1, clen) for causal inputs. Confirm whether
                # this difference is intended.
                for t1 in range(clen):
                    for t2 in range(t1 + 1, clen):
                        bnet_out += (am[t1]+bm[t1], bm[t2])
            else:
                bnet_out += (am.allcomp(), bm.allcomp())
    for x in self.bnet.allcomp():
        if self.bnet.is_fcn(x):
            xm = reg_out_map[x.get_name()]
            bnet_out.set_fcn(xm)
    r.bnet = bnet_out
    r.reg = Region.universe()
    # One copy of the original constraint region per letter.
    for i in range(n):
        tr = self.reg.copy()
        for x in self.reg.allcomprv():
            v = reg_out_map[x.get_name()]
            if isinstance(v, CompArray):
                tr.substitute(x, v[i])
        r.reg = r.reg & tr
    if not toeqprob.isempty():
        # Letter i has the same joint entropy vector as letter 0.
        for i in range(1, n):
            v0 = CompArray.empty()
            vi = CompArray.empty()
            for x in toeqprob:
                v0.append(reg_out_map[x.get_name()][0])
                vi.append(reg_out_map[x.get_name()][i])
            r.reg = r.reg & (ent_vector(*v0) == ent_vector(*vi))
    def map_getlist_id(a, i):
        # Map a Comp to its letter-i copy (rates map to themselves).
        # Expr arguments are copied unchanged.
        if isinstance(a, Expr):
            return a.copy()
        t = None
        for c in a:
            cm = reg_out_map[c.get_name()]
            if isinstance(cm, CompArray):
                cm = cm[i]
            if t is None:
                t = cm.copy()
            else:
                t += cm
        return t
    r.sublist = []
    for i in range(n):
        for v0, v1 in self.sublist:
            r.sublist.append((map_getlist_id(v0, i), map_getlist_id(v1, i)))
    r.nodes = []
    for i in range(n):
        for node in self.nodes:
            v = reg_out_map[node.rv_out.get_name()]
            if isinstance(v, CompArray):
                v = v[i]
            else:
                # Non-array (rate) outputs get a single node only.
                if i != 0:
                    continue
            cnode = node.copy()
            cnode.clear()
            cnode.rv_out = v.copy()
            r.nodes.append(cnode)
    if ret_concmodel is not None:
        # Concrete example model of the n-letter network, with a uniform
        # index Q_o over the n letters.
        cmodel = r.bnet.example_bsc()
        convexify = Comp.index("Q_o").set_card(n)
        cmodel[convexify] = ConcDist.uniform(n)
        for x in self.bnet.allcomp():
            t = reg_out_map[x.get_name()]
            if isinstance(t, CompArray):
                s = CompArray.series_sym(x)
                cmodel.set_series(s, convexify, t)
        ret_concmodel.append(cmodel)
    return r
def nfold_concmodel(self, n = 2, **kwargs):
    """Return the concrete example model of the n-letter setting.

    Calls ``self.nfold(n, first_noindex=False, ret_concmodel=...)`` and
    returns the first model it appends. The returned n-letter CodingModel
    itself is discarded. Extra keyword arguments are forwarded to ``nfold``.
    """
    models = []
    self.nfold(n, first_noindex = False, ret_concmodel = models, **kwargs)
    return models[0]
def convexify_needed(self):
    """Return whether the outer bound requires a time-sharing variable.

    Runs ``get_outer`` in test mode (``convexify=False``,
    ``convexify_test=True``). In that mode ``get_outer`` returns True or
    False instead of a region.
    """
    probe_options = dict(convexify = False, convexify_test = True)
    return self.get_outer(**probe_options)
def get_outer(self, aux = None, oneshot = False, future = True, convexify = None,
        add_csiszar_sum = True, leaf_remove = True, node_fcn = True,
        node_fcn_force = False, skip_simplify = True, convexify_test = False,
        is_proof = False, include_nondecode_series = False, include_last_future = False,
        remove_created = True, full = False, balanced = None, dual_bound_values = None,
        latent_iid = False, **kwargs):
    """Get an outer bound of the capacity region.

    Each non-rate random variable of ``self.bnet`` is replaced by an array
    of copies (``CompArray.series_sym``), and a Bayesian network is built
    over those copies. Rate / decoding constraints are then added for each
    message in ``self.sublist``. Variables created here are existentially
    quantified at the end.

    Parameters
    ----------
    aux : None, int, list, Comp or str
        If an int, ``Region.aux_reduced`` is applied with it. If a list, Comp
        or str, ``Region.aux_set`` is applied to the variables located with
        ``r.find``. None skips both.
    oneshot : bool
        Use a single copy of each variable instead of a series.
    future : bool
        Use three copies per variable (``clen = 3``) instead of two.
    convexify : None, bool or Comp
        Time-sharing variable. None asks ``self.convexify_needed()``. True
        uses ``Comp.index("Q_o")``. False disables time sharing.
    add_csiszar_sum : bool
        Add the Csiszar sum constraint (only when not oneshot and future).
    leaf_remove : bool
        Empty the extra copies of childless variables that satisfy
        ``x.ispresent(self.get_rv_sub(x))``.
    node_fcn, node_fcn_force : bool
        Control whether node outputs are declared as functions of their
        parents.
    skip_simplify : bool
        If False, the result is passed through ``simplified()``.
    convexify_test : bool
        If True, return True/False (whether time sharing is needed) instead
        of a region.
    is_proof : bool
        Enables the extra ``future_also`` and ``coded_combination``
        constraints.
    include_nondecode_series, include_last_future : bool
        If False, the corresponding copies are eliminated via
        ``exists_quick``.
    remove_created : bool
        Eliminate copies of created rate variables from the final network.
    full : bool
        Forces future, add_csiszar_sum, include_nondecode_series and
        include_last_future to True.
    balanced : bool or None
        Passed to aux_reduced / aux_set. None means ``self.is_balanced()``.
    dual_bound_values : object or "2fold"
        Passed to aux_reduced / aux_set. "2fold" is replaced by
        ``self.nfold_concmodel()``.
    latent_iid : bool or Comp
        True uses ``Comp.index("B_o")``. The variable is added as an extra
        parent in node edges and conditioned on in rate constraints.
    **kwargs
        Forwarded to aux_reduced / aux_set.

    Returns
    -------
    Region, or bool when ``convexify_test`` is True.

    Raises
    ------
    ValueError
        If ``self.bnet.tsorted()`` is None (non-strictly-causal cycle).
    """
    if full:
        future = True
        add_csiszar_sum = True
        include_nondecode_series = True
        include_last_future = True
    # Internal switches (not exposed as parameters).
    decoding_rate = True # is_proof
    future_also = is_proof
    nat_msg_rate = False # not is_proof
    coded_combination = is_proof
    # if self.is_netcode():
    #     return self.get_netcode_region(convexify = convexify)
    self.presolve()
    if self.bnet.tsorted() is None:
        raise ValueError("Non-strictly-causal cycle detected.")
        return None  # unreachable: the raise above always fires
    if convexify is None:
        convexify = self.convexify_needed()
    if convexify is not False:
        if convexify is True:
            convexify = Comp.index("Q_o")
    if latent_iid is not False:
        if latent_iid is True:
            latent_iid = Comp.index("B_o")
    latent_iid_rv = Comp.empty()
    if latent_iid is not False and isinstance(latent_iid, Comp):
        latent_iid_rv += latent_iid
    # auxs = self.get_aux()
    # reg_out_aux = auxs.copy()
    reg_out_aux = Comp.empty()
    bnet_out = BayesNet()
    reg_out_map = {}
    clen = 1 if oneshot else 3 if future else 2
    # Order in which copy indices are visited when adding causal edges
    # (index 1 first, then 0, then 2).
    ttr = [1, 0, 2]
    # for v0, v1 in self.sublist:
    #     print(str(v0) + " " + str(v1))
    rv_wfuture = Comp.empty()
    to_csiszar = []
    rv_series = Comp.empty()
    rv_timedep = Comp.empty()
    # Build the array of copies for every variable.
    for x in self.bnet.allcomp():
        removable = False
        removed = False
        if leaf_remove and self.bnet.get_children(x).isempty() and x.ispresent(self.get_rv_sub(x)):
            removable = True
        # print(str(x) + " " + str(self.is_rate(x)) + " " + str(self.get_rv_rates(x)))
        t = None
        if self.is_rate(x):
            # t = CompArray([x.copy()])
            t = x.copy()
        elif oneshot:
            t = CompArray([x.copy()])
        elif future:
            t = CompArray.series_sym(x)
            if removable:
                t[1] = Comp.empty()
                t[2] = Comp.empty()
                removed = True
            else:
                rv_wfuture += x
        else:
            t = CompArray.series_sym(x, suff = None)
        reg_out_map[x.get_name()] = t
        if isinstance(t, CompArray):
            for tt in t:
                rv_timedep += tt
            for tt in t[1:]:
                rv_series += tt
            if not removed:
                to_csiszar.append(t)
    if len(latent_iid_rv):
        bnet_out += latent_iid_rv
    # Translate each edge of the original network to edges between copies.
    for (a, b) in self.bnet.edges():
        # if not self.get_rv_sub(b).ispresent(b):
        #     continue
        am = reg_out_map[a.get_name()]
        bm = reg_out_map[b.get_name()]
        node = self.find_node_rv_out(b)
        if node is None:
            bnet_out += (am.swapped_id(0, 1), bm.swapped_id(0, 1))
        else:
            # if convexify is False and self.is_rate(b) and am.allcomp().ispresent(rv_series):
            #     rv_series += bm.allcomp()
            if node.rv_in_causal.ispresent(a) and isinstance(am, CompArray) and isinstance(bm, CompArray):
                for t1 in range(clen):
                    for t2 in range(t1, clen):
                        if convexify is False:
                            bnet_out += (am[ttr[t1]] + bm[ttr[t1]] + latent_iid_rv, bm[ttr[t2]])
                        else:
                            bnet_out += (am[ttr[t1]] + latent_iid_rv, bm[ttr[t2]])
            else:
                if convexify is False:
                    bnet_out += (am.swapped_id(0, 1).allcomp() + latent_iid_rv, bm.swapped_id(0, 1).allcomp())
                else:
                    for tbm in bm.swapped_id(0, 1).allcomp():
                        bnet_out += (am.swapped_id(0, 1).allcomp() + latent_iid_rv, tbm)
    # Strictly causal node inputs: same as causal, but the (1, 1) pair is skipped.
    for b in self.bnet.allcomp():
        bm = reg_out_map[b.get_name()]
        if not isinstance(bm, CompArray):
            continue
        node = self.find_node_rv_out(b)
        if node is None:
            continue
        for a in node.rv_in_scausal:
            am = reg_out_map[a.get_name()]
            if not isinstance(am, CompArray):
                continue
            for t1 in range(clen):
                for t2 in range(t1, clen):
                    if t1 == 1 and t2 == 1:
                        continue
                    if convexify is False:
                        bnet_out += (am[ttr[t1]] + bm[ttr[t1]], bm[ttr[t2]])
                    else:
                        bnet_out += (am[ttr[t1]], bm[ttr[t2]])
    if convexify is False:
        for x in self.bnet.allcomp():
            parent_empty = self.bnet.get_parents(x).isempty()
            if parent_empty:
                xm = reg_out_map[x.get_name()]
                if isinstance(xm, CompArray) and len(xm) >= 3:
                    bnet_out += (xm[1], xm[2])
    # Attach copies of sources / node outputs that have no series parent
    # to the time-sharing variable (or to rv_series[0] without time sharing).
    rvnats = Comp.empty()
    for x in self.bnet.allcomp():
        parent_empty = self.bnet.get_parents(x).isempty()
        if parent_empty or self.find_node_rv_out(x) is not None:
            xm = reg_out_map[x.get_name()]
            if isinstance(xm, CompArray):
                for txm in (xm[1:] if parent_empty else xm):
                    if txm.isempty():
                        continue
                    series_parent = bnet_out.get_parents(txm).ispresent(rv_series)
                    if series_parent:
                        continue
                    rvnats += txm
                    if convexify is not False:
                        bnet_out += (convexify, txm)
                    else:
                        if not txm.ispresent(rv_series[0]):
                            if convexify_test:
                                # print((rv_series[0], txm))
                                return True
                            bnet_out += (rv_series[0], txm)
    # print(rv_series)
    # Functional dependencies of node outputs.
    r_fcn = Region.universe()
    for x in self.bnet.allcomp():
        if self.bnet.is_fcn(x):
            xm = reg_out_map[x.get_name()]
            bnet_out.set_fcn(xm)
        elif node_fcn and self.find_node_rv_out(x) is not None:
            xm = reg_out_map[x.get_name()]
            if isinstance(xm, CompArray):
                for txm in xm:
                    if txm.isempty():
                        continue
                    series_parent = bnet_out.get_parents(txm).ispresent(rv_series)
                    if node_fcn_force or convexify is not False or series_parent:
                        bnet_out.set_fcn(txm)
                        # print(txm)
                        if convexify is False:
                            for b in rv_series:
                                b2 = bnet_out.get_parents(txm) - xm.allcomp() + b
                                if not b2.super_of(txm):
                                    r_fcn &= Expr.Hc(txm, b2) == 0
            else:
                if future:
                    bnet_out.set_fcn(xm)
    bnet_out_final = bnet_out.copy()
    if remove_created:
        for x in self.bnet.allcomp():
            if self.created_rv.ispresent(x) and self.is_rate(x):
                xm = reg_out_map[x.get_name()]
                for a in xm:
                    bnet_out_final = bnet_out_final.eliminated(a)
        bnet_out_final = bnet_out_final.scc()
    # if convexify_test:
    #     return False
    if not convexify_test:
        # print(bnet_out)
        # Keep the constructed networks on the instance.
        self.bnet_out_final = bnet_out_final
        self.bnet_out = bnet_out
        self.bnet_out_arrays = []
        for x in self.bnet.allcomp():
            xm = reg_out_map[x.get_name()]
            if isinstance(xm, CompArray):
                t = xm.swapped_id(0, 1).allcomp()
                if len(t) > 1:
                    self.bnet_out_arrays.append(t)
    r = bnet_out_final.get_region().add_meta("pf_note", ["Bayesian network"])
    # if not self.reg.isuniverse():
    #     # r = r & self.reg
    #     r = r.and_cause_consequence(self.reg)
    if not convexify_test and not oneshot and future and add_csiszar_sum:
        # r &= csiszar_sum(*(list(x for _, x in reg_out_map.items())))
        r &= csiszar_sum(*to_csiszar).add_meta("pf_note", ["Csiszar sum"]).add_meta("dual_weight", 5.0)
    r &= r_fcn
    if convexify is not False:
        reg_out_aux += convexify
        # The time-sharing variable is determined by every series copy.
        for x in rv_series:
            r &= Expr.Hc(convexify, x) == 0
    # All copies created in this function are auxiliaries.
    reg_out_aux += bnet_out.allcomp() - self.bnet.allcomp()
    # for b in self.bnet.allcomp():
    #     b2 = self.get_rv_sub(b)
    #     if b2.ispresent(b):
    #         continue
    #     if self.is_rate(b):
    #         rate = self.get_rv_rate(b)
    #         r &= rate >= 0
    def map_getlist(a):
        # Map a Comp to the sum of its copies. Rate (plain Comp) entries are
        # merged into index 0 of the array.
        t = None
        for c in a:
            cm = reg_out_map[c.get_name()].copy()
            if t is None:
                t = cm.copy()
            else:
                if isinstance(t, Comp) and isinstance(cm, CompArray):
                    t, cm = cm, t
                if isinstance(cm, Comp) and isinstance(t, CompArray):
                    t[0] = t[0] + cm
                else:
                    t = t + cm
        return t
    def map_substitute(r, a, b):
        # Substitute the copies of a by the corresponding copies of b in r.
        am = map_getlist(a)
        bm = map_getlist(b)
        if isinstance(am, Comp) or isinstance(bm, Comp):
            r.substitute(am.allcomp(), bm.allcomp())
        else:
            for a2, b2 in zip(am, bm):
                r.substitute(a2, b2)
    # If all sequences are i.i.d. conditional on codebook
    # if latent_iid is not False and convexify is not False:
    #     seqlist = [Comp.empty(), Comp.empty(), Comp.empty()]
    #     for a in self.bnet.index.comprv:
    #         am = map_getlist(a)
    #         if isinstance(am, CompArray):
    #             for i, t in enumerate(am):
    #                 seqlist[i] += t
    #     if not seqlist[0].isempty():
    #         r &= (Expr.I(seqlist[0], convexify) == 0)
    #         r &= indep(seqlist[0], seqlist[1], seqlist[2]).conditioned(convexify)
    # Source-channel data processing
    if (convexify_test or (not oneshot and convexify is not False)) and all(node.rv_in_scausal.isempty() for node in self.nodes):
        cconvexify = Comp.empty()
        if convexify is not False:
            cconvexify = convexify
        seqs = sum((a for a in self.bnet.allcomp() if isinstance(reg_out_map[a.get_name()], CompArray)), Comp.empty())
        srcs = self.get_srcs()
        chans = self.get_chans()
        for sa in igen.subset(srcs.inter(seqs), minsize=1):
            sam = map_getlist(sa)
            for sb, sc2 in chans:
                if not seqs.super_of(sb):
                    continue
                sbm = map_getlist(sb)
                for sc in igen.subset(sc2.inter(seqs), minsize=1):
                    scm = map_getlist(sc)
                    if self.bnet.check_ic(Expr.Ic(sa, sc, sb)):
                        for sd in igen.subset(seqs - sa - sb, minsize=1):
                            sdm = map_getlist(sd)
                            if self.bnet.check_ic(Expr.Ic(sa + sb, sd, sc)):
                                # Data-processing bound on copy index 0:
                                # I(sb;sc|Q) >= I(sa;sd|Q).
                                cexpr = Expr.Ic(sbm[0], scm[0], cconvexify) - Expr.Ic(sam[0], sdm[0], cconvexify)
                                cexpr.simplify_quick()
                                if cexpr.isnonneg():
                                    continue
                                # print(sa, " - ", sb, " - ", sc, " - ", sd)
                                if convexify_test:
                                    return True
                                r &= (cexpr >= 0).add_meta("pf_note", ["data proc. ", sa, " - ", sb, " - ", sc, " - ", sd])
        # for sa in igen.subset(seqs, minsize=1):
        #     sam = map_getlist(sa)
        #     if bnet_out.check_ic(Expr.Ic(sam[0], sam[1], cconvexify)):
        #         for sc in igen.subset(seqs - sa, minsize=1):
        #             sb = sum((self.bnet.get_parents(c) for c in sc), Comp.empty()) - sc
        #             if sb.isempty() or not seqs.super_of(sb):
        #                 continue
        #             sbm = map_getlist(sb)
        #             scm = map_getlist(sc)
        #             if bnet_out.check_ic(Expr.Ic(scm[0], sum(sbm) + scm[1], sbm[0] + cconvexify)):
        #                 if bnet_out.check_ic(Expr.Ic(sum(sam), sum(scm), sum(sbm))):
        #                     for sd in igen.subset(seqs - sa - sb, minsize=1):
        #                         sdm = map_getlist(sd)
        #                         if bnet_out.check_ic(Expr.Ic(sum(sam) + sum(sbm), sum(sdm), sum(scm))):
        #                             if convexify_test:
        #                                 return True
        #                             r &= (Expr.Ic(sam[0], sdm[0], cconvexify) <= Expr.Ic(sbm[0], scm[0], cconvexify)).add_meta("pf_note", ["data proc. ", sa, " - ", sb, " - ", sc, " - ", sd])
    if convexify_test:
        return False
    # Collect message copies: msgs = all messages (entries of sublist whose
    # value is an Expr); msgnats = messages without parents in self.bnet.
    msgs = Comp.empty()
    msgnats = Comp.empty()
    for v0, v1 in self.sublist:
        if isinstance(v1, Expr) and not isinstance(v0, Expr):
            msgs += reg_out_map[v0[0].get_name()]
            vout = self.bnet.get_parents(v0[0])
            if vout.isempty():
                msgnats += reg_out_map[v0[0].get_name()]
    # Decoding-rate constraints for variables whose rate rvs differ from
    # themselves (get_rv_ratervs(a) != a).
    if decoding_rate:
        for a in self.bnet.index.comprv:
            a2 = self.get_rv_ratervs(a)
            if a != a2:
                vout = self.bnet.get_parents(a)
                voutm = map_getlist(vout)
                voutm0 = Comp.empty()
                voutm1 = Comp.empty()
                voutm2 = None
                if isinstance(voutm, CompArray):
                    voutm0 = voutm[0]
                    if len(voutm) >= 2:
                        voutm1 = voutm[1]
                    if len(voutm) >= 3:
                        voutm2 = voutm[2]
                else:
                    continue
                for a3 in a2:
                    rate = self.get_rv_rates(a3)
                    for cmsg in igen.subset(msgnats - a3, minsize = 1):
                        r &= (Expr.Ic(a3, voutm0, voutm1 + cmsg) >= Expr.Ic(a3, voutm0, voutm1)
                            ).add_meta("pf_note", ["indep. of msgs ", a3, ", ", cmsg])
                        if future_also and voutm2 is not None:
                            r &= (Expr.Ic(a3, voutm0, voutm2 + cmsg) >= Expr.Ic(a3, voutm0, voutm2)
                                ).add_meta("pf_note", ["indep. of msgs ", a3, ", ", cmsg])
                    r &= (Expr.Ic(a3, voutm0, voutm1 + latent_iid_rv) >= rate
                        ).add_meta("pf_note", ["decode ", a3]).add_meta("dual_weight", 0.1)
                    if future_also and voutm2 is not None:
                        r &= (Expr.Ic(a3, voutm0, voutm2 + latent_iid_rv) >= rate
                            ).add_meta("pf_note", ["decode ", a3]).add_meta("dual_weight", 0.15)
    # Rate constraints for each message in sublist. Other substitutions
    # are applied directly.
    codedmsgs = []
    for v0, v1 in self.sublist:
        if isinstance(v0, Comp):
            rv_wfuture -= v0
        if isinstance(v1, Expr) and not isinstance(v0, Expr):
            map_substitute(r, v0, v0[0])
            isnat = False
            vout = self.bnet.get_parents(v0[0])
            if vout.isempty():
                # Parentless message: relate it to its children instead.
                vout = self.bnet.get_children(v0[0])
                isnat = True
            voutm = map_getlist(vout)
            voutm0 = Comp.empty()
            voutm1 = Comp.empty()
            voutm2 = None
            if isinstance(voutm, CompArray):
                voutm0 = voutm[0]
                if len(voutm) >= 2:
                    voutm1 = voutm[1]
                if len(voutm) >= 3:
                    voutm2 = voutm[2]
            else:
                voutm0 = voutm
            am = reg_out_map[v0[0].get_name()]
            if isnat:
                if oneshot:
                    r &= Expr.H(am) >= v1
                    r &= v1 >= 0
                else:
                    for cmsg in igen.subset(msgnats - am, minsize = 1):
                        r &= (Expr.Ic(am, voutm0, voutm1 + cmsg) >= Expr.Ic(am, voutm0, voutm1)
                            ).add_meta("pf_note", ["indep. of msgs ", am, ", ", cmsg])
                    if nat_msg_rate:
                        r &= (Expr.Ic(am, voutm0, voutm1 + latent_iid_rv) >= v1
                            ).add_meta("pf_note", ["rate of ", am])
                    r &= v1 >= 0
                    # r &= Expr.Ic(am, voutm0, voutm1) == v1
            else:
                if oneshot:
                    r &= Expr.H(am) <= v1
                else:
                    codedmsgs.append((am, voutm0, voutm1, voutm2))
                    for ctv in igen.subset(rv_wfuture - vout):
                        tva = map_getlist(ctv)
                        tv = None
                        if tva is None:
                            tv = Comp.empty()
                        else:
                            tv = tva[0] + tva[1] + tva[2]
                        # print(am)
                        # print(ctv)
                        # print(tva)
                        # print(rv_wfuture)
                        # print()
                        for cmsg in igen.subset(msgs - am):
                            r &= (Expr.Ic(am, voutm0, voutm1 + tv + cmsg + latent_iid_rv) <= v1
                                ).add_meta("pf_note", ["rate of ", am])
                            if coded_combination and tva is not None:
                                r &= (Expr.Ic(am, voutm0 + tva[0], voutm1 + tva[1] + cmsg + latent_iid_rv) <= v1
                                    ).add_meta("pf_note", ["rate of ", am]).add_meta("dual_weight", 0.5)
                            if future_also and voutm2 is not None:
                                r &= (Expr.Ic(am, voutm0, voutm2 + tv + cmsg + latent_iid_rv) <= v1
                                    ).add_meta("pf_note", ["rate of ", am])
                                if coded_combination and tva is not None:
                                    r &= (Expr.Ic(am, voutm0 + tva[0], voutm2 + tva[2] + cmsg + latent_iid_rv) <= v1
                                        ).add_meta("pf_note", ["rate of ", am]).add_meta("dual_weight", 0.5)
                    # r &= Expr.Ic(am, voutm0, voutm1) == v1
                    # print(v0)
                    # print(v1)
            reg_out_aux += am
        else:
            map_substitute(r, v0, v1)
    # print(r)
    for tozero, toconst in self.sublist_const:
        r.substitute(tozero, toconst)
    # Optionally eliminate the extra copies of non-decoded variables and
    # the last copy (index 2) of the last decoded variable.
    if not include_nondecode_series or not include_last_future:
        rv_decode = Comp.empty()
        rv_nondecode = Comp.empty()
        for a in self.bnet.index.comprv:
            alist = map_getlist(a)
            if not (isinstance(alist, CompArray) and len(alist) > 1):
                continue
            if any(self.find_node_rv_out(x) is not None for x in self.bnet.get_children(a)):
                rv_decode += a
            else:
                rv_nondecode += a
        rv_toelim = Comp.empty()
        if not include_nondecode_series:
            for a in rv_nondecode:
                alist = map_getlist(a)
                for i, b in enumerate(alist):
                    if i:
                        rv_toelim += b
        if not include_last_future:
            if len(rv_decode):
                alist = map_getlist(rv_decode[len(rv_decode) - 1])
                if len(alist) >= 3:
                    rv_toelim += alist[2]
        if not rv_toelim.isempty():
            r = r.exists_quick(rv_toelim, method = "ci")
    if not self.reg.isuniverse():
        # r = r & self.reg
        r = r.and_cause_consequence(self.reg)
    reg_out_aux = reg_out_aux.inter(r.allcomprv())
    r = r.exists(reg_out_aux)
    if aux is not None and isinstance(aux, int):
        # Reduce the region to `aux` auxiliaries, guided by score_fcn.
        aux_pairs = []
        aux_force = msgs
        if convexify is not False:
            aux_force = aux_force + convexify
        channel_ins = Comp.empty()
        for a in self.bnet.index.comprv:
            voutm = map_getlist(a)
            if isinstance(voutm, CompArray) and len(voutm) >= 3:
                if r.ispresent(voutm[1]) and r.ispresent(voutm[2]):
                    aux_pairs.append((voutm[1], voutm[2]))
                pa = self.bnet.get_parents(a)
                if pa.ispresent(msgnats):
                    channel_ins += sum(voutm, Comp.empty())
        msgs_sorted = Comp.empty()
        for a in self.bnet.index.comprv:
            if msgs.ispresent(a):
                msgs_sorted += a
        msgs_sorted += msgs
        def score_fcn(a):
            # Heuristic score of a candidate auxiliary tuple `a`.
            # Returns None to reject, else (score, tie-break).
            r = 0
            msgs_vis = [False] * len(msgs_sorted)
            hastime = False
            hassingleconvexify = False
            if len(a) >= 2:
                for b1, b2 in itertools.permutations(a, 2):
                    if b1.super_of(b2):
                        return None
            for b in a:
                hasconvexify = False
                hasseries = b.ispresent(rv_series)
                if hasseries:
                    hastime = True
                if convexify is not False and b.ispresent(convexify):
                    hasconvexify = True
                    hastime = True
                    r -= 1
                    if hasseries:
                        return None
                nmsg = 0
                for i, msg in enumerate(msgs_sorted):
                    if b.ispresent(msg):
                        if msgs_vis[i]:
                            r += 100
                        else:
                            r -= i
                        msgs_vis[i] = True
                        nmsg += 1
                if nmsg == 0:
                    if not hasconvexify:
                        r += 20
                elif nmsg > 1:
                    r += nmsg * 70
                if len(b) == 1:
                    if not hasconvexify:
                        r += 5
                elif len(b) > 2:
                    r += len(b) * 5
                if b.ispresent(channel_ins):
                    r += 30
                if hasconvexify:
                    if len(b) == 1:
                        hassingleconvexify = True
                    if len(b) > 1:
                        r += len(b) * 10
            for t in msgs_vis:
                if not t:
                    r += 10
            if not hastime:
                r += 20
            return (r, -1 if hassingleconvexify else 0)
        if balanced is None:
            balanced = self.is_balanced()
        if isinstance(dual_bound_values, str) and dual_bound_values == "2fold":
            dual_bound_values = self.nfold_concmodel()
            r.add_real_to(dual_bound_values)
        r = r.aux_reduced(aux, aux_pairs = aux_pairs, aux_force = aux_force, score_fcn = score_fcn,
                balanced = balanced, dual_bound_values = dual_bound_values, **kwargs)
    elif aux is not None and isinstance(aux, (list, Comp, str)):
        # Explicit auxiliary choice: entries are names or (key, name) tuples.
        aux2 = []
        for a in aux if isinstance(aux, list) else [aux]:
            if isinstance(a, tuple):
                aux2.append((a[0], rv(r.find(a[1]))))
            else:
                aux2.append(rv(r.find(a)))
        if balanced is None:
            balanced = self.is_balanced()
        if isinstance(dual_bound_values, str) and dual_bound_values == "2fold":
            dual_bound_values = self.nfold_concmodel()
            r.add_real_to(dual_bound_values)
        r = r.aux_set(aux2, balanced = balanced, dual_bound_values = dual_bound_values, **kwargs)
    if not skip_simplify:
        return r.simplified()
    else:
        return r
def get_outer_nfold(self, n = 2, **kwargs):
    """Return the outer bound of the n-letter setting, divided by n.

    Builds ``self.nfold(n)``, calls ``get_outer(**kwargs)`` on it, and
    divides the resulting region by the blocklength ``n``.
    """
    n_letter_model = self.nfold(n)
    outer_bound = n_letter_model.get_outer(**kwargs)
    return outer_bound / n
def proof_inner(self, r, *args, **kwargs):
    """Proof of an inner bound.

    Intersects ``r`` with ``self.get_region()`` (and, when ``self.reg`` is
    not the universe, applies ``and_cause_consequence(self.reg)``). It then
    requests a proof-mode inner bound with that region as target. Each
    region of the resulting union is tried in turn.

    Returns
    -------
    The proof from ``PsiOpts.get_proof()`` for the first union component
    ``x`` such that ``r2 >> x.exists(r_inner.aux)`` holds, or None when
    ``get_inner`` returns None or no component works.
    """
    r2 = r & self.get_region()
    if not self.reg.isuniverse():
        r2 = r2.and_cause_consequence(self.reg)
    # with PsiOpts(simplify_aux_all = False):
    #     r_inner = self.get_inner(*args, **kwargs)
    r_inner = self.get_inner(*args, is_proof = True, target = r2, **kwargs)
    if r_inner is None:
        return None
    r_inner = r_inner.tounion()
    for x, _ in r_inner.regs:
        with PsiOpts(proof_new = True):
            if r2 >> x.exists(r_inner.aux):
                return PsiOpts.get_proof()
    return None
def proof_outer(self, r, future = None, *args, **kwargs):
    """Proof of an outer bound.

    ``future`` may be None (try False, then True), a single bool, or an
    iterable of values. For each value, a proof-mode outer bound is
    computed with ``get_outer(..., is_proof=True, future=value)``. The proof
    for the first one that implies ``r`` is returned; otherwise None.
    """
    if future is None:
        future_choices = [False, True]
    elif isinstance(future, bool):
        future_choices = [future]
    else:
        future_choices = future
    for use_future in future_choices:
        outer_region = self.get_outer(*args, is_proof = True, future = use_future, **kwargs)
        with PsiOpts(proof_new = True):
            if outer_region >> r:
                return PsiOpts.get_proof()
    return None
def optimum(self, v, b, sn, name = None, inner = True, outer = True, inner_kwargs = None, outer_kwargs = None, tighten = False, **kwargs):
    """Return the variable obtained from maximizing (sn=1)
    or minimizing (sn=-1) the expression v over variables b (Comp, Expr or list)

    The inner bound (``get_inner(**inner_kwargs)``) is preferred, with the
    outer bound passed as ``reg_outer``. If no inner bound is available,
    the outer bound (``get_outer(**outer_kwargs)``) is optimized instead.
    When ``name`` is given, ``PsiOpts.settings["fcn_suffix"]`` is appended.
    Returns None if neither bound is available.
    """
    # if self.is_netcode():
    #     outer = False
    inner_region = None
    if inner:
        inner_region = self.get_inner(**(dict() if inner_kwargs is None else inner_kwargs))
    outer_region = None
    if outer:
        outer_region = self.get_outer(**(dict() if outer_kwargs is None else outer_kwargs))
    if name is not None:
        name = name + PsiOpts.settings["fcn_suffix"]
    if inner_region is not None:
        return inner_region.optimum(v, b, sn, name = name, reg_outer = outer_region, tighten = tighten, quick = None, quick_outer = True, **kwargs)
    if outer_region is not None:
        return outer_region.optimum(v, b, sn, name = name, tighten = tighten, quick = True, **kwargs)
    return None
def maximum(self, expr, vs, **kwargs):
    """Return the variable obtained from maximizing the expression expr
    over variables vs (Comp, Expr or list)

    Thin wrapper around ``optimum`` with sign +1.
    """
    maximize_sign = 1
    return self.optimum(expr, vs, maximize_sign, **kwargs)
def minimum(self, expr, vs, **kwargs):
    """Return the variable obtained from minimizing the expression expr
    over variables vs (Comp, Expr or list)

    Thin wrapper around ``optimum`` with sign -1.
    """
    minimize_sign = -1
    return self.optimum(expr, vs, minimize_sign, **kwargs)
def node_groups(self):
    """Group the nodes of the model by compatible parent sets.

    Returns a list of ``[outputs, parents, first_node]`` lists. A node
    joins the first existing group whose parents it contains and whose
    parents plus outputs contain the node's parents. Its outputs are then
    appended to that group's outputs. Otherwise the node starts a new
    group.
    """
    groups = []
    for node in self.nodes:
        parents = self.bnet.get_parents(node.rv_out)
        outputs = node.rv_out.copy()
        match = next((g for g in groups
                      if parents.super_of(g[1]) and (g[1] + g[0]).super_of(parents)), None)
        if match is None:
            groups.append([outputs, parents, node])
        else:
            match[0] += outputs
    return groups
def node_group_info(self):
    """Return the node groups; currently identical to ``node_groups()``."""
    return self.node_groups()
def channel_groups(self):
    """Group non-node variables into channels.

    Variables produced by a node (``find_node_rv_out`` is not None) are
    skipped. A variable joins the first group whose inputs plus current
    outputs equal its parent set. Otherwise it starts a new group with its
    parents as inputs.

    Returns
    -------
    list of (inputs, outputs) tuples.
    """
    groups = []
    for a in self.bnet.allcomp():
        if self.find_node_rv_out(a) is not None:
            continue
        pa = self.bnet.get_parents(a)
        for group in groups:
            if pa == group[0] + group[1]:
                # Assign through the list so the group is updated even when
                # Comp addition returns a new object. Do not rely on
                # in-place __iadd__ through a tuple-unpacked loop variable.
                group[1] = group[1] + a
                break
        else:
            groups.append([pa, a])
    return [tuple(group) for group in groups]
def graph(self, lr = True, enc_node = True, channel_box = True, rv_on_edge = False, seq_op = ("append_super", "n"), recovered_op = ("append", "'"), ortho = False, **kwargs):
    """Return the graphviz digraph of the network that can be displayed in the console.

    Parameters:
        lr : If True, set the graph attribute rankdir = "LR".
        enc_node : If True, draw each encoder group from node_groups() as a
            node labelled with its 1-based index (or node.label converted by
            iutil.latex_to_text when set). Edges go from the group's parents
            and to its outputs, with dashed edges from node.rv_in_scausal.
        channel_box : If True, draw each channel group from channel_groups()
            as a node labelled prob_symbol(outputs) or
            prob_symbol(outputs | inputs). Groups with no inputs and at most
            one output are not drawn.
        rv_on_edge : If True, a random variable with exactly one incoming and
            one outgoing edge is drawn as the label of a single merged edge.
            Other random variables become small point-like nodes, and their
            label is placed on their incoming edges (or on their outgoing
            edges when they have no incoming edge).
        seq_op : Name operation applied to the label of every random variable
            for which self.is_rate(a) is False (None to disable).
        recovered_op : Name operation applied to the label when
            self.get_rv_sub(a) differs from a (None to disable).
        ortho : If True, set the graph attribute splines = "ortho".
        **kwargs : Extra graph attributes; values are converted with str().

    Returns:
        A graphviz.Digraph.

    Raises:
        RuntimeError : If the graphviz package is not available.
    """
    if graphviz is None:
        raise RuntimeError("Requires graphviz. Please install it first.")
    r = graphviz.Digraph()
    if lr:
        r.graph_attr["rankdir"] = "LR"
    if ortho:
        r.graph_attr["splines"] = "ortho"
    for key, value in kwargs.items():
        r.graph_attr[key] = str(value)
    node_shape = PsiOpts.settings.get("codingmodel_node_shape", "rect")
    node_fillcolor = PsiOpts.settings.get("codingmodel_node_fillcolor", "white")
    channel_shape = PsiOpts.settings.get("codingmodel_channel_shape", "rect")
    channel_fillcolor = PsiOpts.settings.get("codingmodel_channel_fillcolor", "white")
    groups = self.node_groups()
    channel_groups = []
    if channel_box:
        channel_groups = self.channel_groups()
    # Random variables not yet produced by a drawn encoder/channel box; plain
    # Bayesian-network edges are only drawn into these.
    rvs = self.bnet.allcomp()
    # Edges are collected first as (tail, head, label, style) and emitted at
    # the end, so that rv_on_edge can still merge/relabel them. style is None,
    # "dashed", or the internal marker "no_arrowhead".
    edges = []
    if enc_node:
        for i, (b, a, node) in enumerate(groups):
            cname = "enc_" + str(a) + "_" + str(b)
            label = str(i + 1)
            if node.label is not None:
                label = iutil.latex_to_text(node.label, style="graphviz")
            r.node(cname, label, shape = node_shape, style = "filled", fillcolor = node_fillcolor)
            for ai in a:
                edges.append((ai.get_name(), cname, "", None))
                # r.edge(ai.get_name(), cname)
            for bi in b:
                edges.append((cname, bi.get_name(), "", None))
                # r.edge(cname, bi.get_name())
            for ai in node.rv_in_scausal:
                edges.append((ai.get_name(), cname, "", "dashed"))
                # r.edge(ai.get_name(), cname, style = "dashed")
            rvs -= b
    if channel_box:
        for i, (a, b) in enumerate(channel_groups):
            # A single unconditioned output needs no channel box.
            if a.isempty() and len(b) <= 1:
                continue
            cname = "ch_" + str(a) + "_" + str(b)
            if a.isempty():
                label = prob_symbol(b).graphviz_text()
                # label = "P(" + b.graphviz_text() + ")"
            else:
                label = prob_symbol(b|a).graphviz_text()
                # label = "P(" + b.graphviz_text() + "|" + a.graphviz_text() + ")"
            r.node(cname, label, shape = channel_shape, style = "filled", fillcolor = channel_fillcolor)
            for ai in a:
                edges.append((ai.get_name(), cname, "", None))
                # r.edge(ai.get_name(), cname)
            for bi in b:
                edges.append((cname, bi.get_name(), "", None))
                # r.edge(cname, bi.get_name())
            rvs -= b
    for (a, b) in self.bnet.edges():
        if b in rvs:
            edges.append((a.get_name(), b.get_name(), "", None))
            # r.edge(a.get_name(), b.get_name())
    for a in self.bnet.allcomp():
        shape = "plaintext" #"oval"
        node = self.find_node_rv_out(a)
        asub = self.get_rv_sub(a)
        if recovered_op is not None and asub != a:
            asub = asub.name_operation(recovered_op)
        if seq_op is not None and not self.is_rate(a):
            asub = asub.name_operation(seq_op)
        label = asub.graphviz_text()
        # Without separate encoder nodes, encoder outputs get the node shape.
        if node is not None and not enc_node:
            shape = node_shape
        if rv_on_edge:
            inedges = [(x, y, t, s) for x, y, t, s in edges if y == a.get_name()]
            outedges = [(x, y, t, s) for x, y, t, s in edges if x == a.get_name()]
            if len(inedges) == 1 and len(outedges) == 1:
                # Replace in -> a -> out by a single labelled edge in -> out.
                edges = [(x, y, t, s) for x, y, t, s in edges if x != a.get_name() and y != a.get_name()]
                edges.append((inedges[0][0], outedges[0][1], label, outedges[0][3]))
            else:
                # Keep a as a tiny node: a visible dot at junctions, invisible otherwise.
                if len(inedges) + len(outedges) > 1:
                    r.node(a.get_name(), "", shape="circle", width="0.05", height="0.05", style="filled", fillcolor="black")
                else:
                    r.node(a.get_name(), "", shape="none", width="0", height="0")
                # Put the label on the incoming edges (dropping their arrowhead
                # when a has outgoing edges), or on the outgoing edges if a
                # has no incoming edge.
                for i in range(len(edges)):
                    if edges[i][1] == a.get_name():
                        edges[i] = (edges[i][0], edges[i][1], label, "no_arrowhead" if len(outedges) else edges[i][3])
                    elif edges[i][0] == a.get_name() and len(inedges) == 0:
                        edges[i] = (edges[i][0], edges[i][1], label, edges[i][3])
        else:
            if shape == "plaintext":
                # r.node(a.get_name(), label, shape = shape, margin = "0", color = "none", width = "0", height = "0")
                # r.node(a.get_name(), label, shape = shape, color = "none", width = "0", height = "0")
                r.node(a.get_name(), label, shape = shape, color = "none", width = "0")
            else:
                r.node(a.get_name(), label, shape = shape)
    for x, y, t, s in edges:
        edge_kwargs = {}
        if t != "":
            edge_kwargs["label"] = t
        if s is not None:
            # "no_arrowhead" is an internal marker, not a graphviz style.
            if s == "no_arrowhead":
                edge_kwargs["arrowhead"] = "none"
            else:
                edge_kwargs["style"] = s
        r.edge(x, y, **edge_kwargs)
        # if s is None:
        #     if t == "":
        #         r.edge(x, y)
        #     else:
        #         r.edge(x, y, label=t)
        # else:
        #     if t == "":
        #         r.edge(x, y, style=s)
        #     else:
        #         r.edge(x, y, label=t, style=s)
    return r
def graph_outer(self, **kwargs):
    """Return the graphviz digraph of the outer-bound Bayesian network.

    Delegates to self.bnet_out.graph with groups = self.bnet_out_arrays;
    extra keyword arguments are forwarded.
    """
    draw = self.bnet_out.graph
    return draw(groups = self.bnet_out_arrays, **kwargs)
class CommEnc(IBaseObj):
    """Encoder of a communication model (added to a CommModel with +=).

    Parameters:
        rv_out : Comp produced by the encoder.
        msgs : Messages encoded. Each entry is either an Expr, or an
            indexable pair whose [0] is the rate and [1] the auxiliary Comp
            (this is how CommModel.create_reg_in consumes them).
        rv_in_causal : Comp of inputs observed causally; they are excluded
            from the encoder's inputs in CommModel.create_reg_in. Defaults to
            the empty Comp.
    """

    def __init__(self, rv_out, msgs, rv_in_causal = None):
        self.msgs = msgs
        self.rv_out = rv_out
        self.rv_in_causal = Comp.empty() if rv_in_causal is None else rv_in_causal

    def record_to(self, index, skip_msg = False):
        """Record the output and causal-input variables into index, followed
        (unless skip_msg) by every item of every message."""
        for comp in (self.rv_out, self.rv_in_causal):
            comp.record_to(index)
        if skip_msg:
            return
        for msg in self.msgs:
            for item in msg:
                item.record_to(index)
class CommDec(IBaseObj):
    """Decoder of a communication model (added to a CommModel with +=).

    Parameters:
        rv_in : Comp observed by the decoder.
        msgs : Rate expressions of the messages this decoder must recover.
        msgints : Optional rate expressions of additional messages; CommModel
            uses them to fix which earlier messages are decoded jointly
            (presumably non-unique decoding; confirm against CommModel usage).
    """

    def __init__(self, rv_in, msgs, msgints = None):
        self.rv_in = rv_in
        self.msgs = msgs
        self.msgints = msgints

    def record_to(self, index):
        """Record rv_in, then every item of msgs and of msgints (if any), into index."""
        self.rv_in.record_to(index)
        groups = list(self.msgs)
        if self.msgints is not None:
            groups.extend(self.msgints)
        for group in groups:
            for item in group:
                item.record_to(index)
class CommModel(IBaseObj):
    """Communication model: a Bayesian network together with encoders and decoders.

    Encoders (CommEnc) and decoders (CommDec) are added with ``+=``; anything
    else added with ``+=`` goes to the underlying BayesNet. get_inner() and
    get_outer() return inner / outer bound regions on the message rates.

    After create_reg_in(), self.msgs holds one entry per encoded message:
    [encoder inputs (excluding causal inputs), rate, auxiliary Comp,
    encoder output].
    """

    def __init__(self, bnet = None, reg = None, nature = None):
        self.bnet = BayesNet() if bnet is None else bnet
        self.reg = Region.universe() if reg is None else reg
        # Populated by create_reg_in / create_reg_out / get_inner_iter.
        self.bnet_in = None
        self.reg_in = None
        self.bnet_out = None
        self.reg_out = None
        self.reg_out_aux = None
        self.reg_out_tsrv = None
        self.reg_out_vs = None
        self.reg_out_map = None
        self.msgs = None
        self.maux_rt = None
        self.maux_rtsub = None
        self.decs = None
        self.nature = Comp.empty() if nature is None else nature
        self.enclist = []
        self.declist = []

    def __iadd__(self, other):
        """Add an encoder, a decoder, or (otherwise) an item of the Bayesian network."""
        if isinstance(other, CommEnc):
            self.enclist.append(other)
        elif isinstance(other, CommDec):
            self.declist.append(other)
        else:
            self.bnet += other
        return self

    def create_reg_in(self, name_prefix = "A_"):
        """Build self.bnet_in, self.msgs and self.reg_in (inner-bound network).

        A message given as a bare Expr gets a fresh auxiliary random variable
        named name_prefix + str(message), renamed if needed to avoid clashes
        with existing names.
        """
        self.bnet_in = self.bnet.copy()
        self.msgs = []
        self.decs = []
        index = IVarIndex()
        self.reg.record_to(index)
        for x in self.enclist:
            x.record_to(index)
        for x in self.declist:
            x.record_to(index)
        for x in self.enclist:
            cauxall = Comp.empty()
            rv_in = self.bnet.get_parents(x.rv_out)
            rv_in = rv_in - x.rv_in_causal
            for m in x.msgs:
                if isinstance(m, Expr):
                    tname = index.name_avoid(name_prefix + str(m))
                    caux = Comp.rv(tname)
                    caux.record_to(index)
                else:
                    caux = m[1].copy()
                self.msgs.append([rv_in.copy(), m[0], caux, x.rv_out.copy()])
                cauxall += caux
            # inputs -> auxiliaries -> output
            self.bnet_in += (rv_in, cauxall)
            self.bnet_in += (cauxall, x.rv_out)
        self.reg_in = self.bnet_in.get_region() & self.reg

    def get_encout(self):
        """Return the union of the outputs of all encoders."""
        r = Comp.empty()
        for x in self.enclist:
            r += x.rv_out
        return r

    def enc_id_get_aux(self, i):
        """Return the auxiliaries of the messages whose output lies in encoder i's output."""
        r = Comp.empty()
        for rv_in, rtbs, b, rv_out in self.msgs:
            if self.enclist[i].rv_out.ispresent(rv_out):
                r += b
        return r

    def create_reg_out(self, future = True):
        """Build self.bnet_out, self.reg_out, self.reg_out_map and self.reg_out_aux.

        Every random variable is replaced by its series array
        (CompArray.series_sym; without the future part when future is False).
        When future is True, a Csiszar-sum constraint over all series arrays
        and all message auxiliaries is added.
        """
        self.create_reg_in("M_")
        # Auxiliaries of all messages. Start from the empty Comp so that a
        # model without messages does not fail on int.copy().
        all_auxs = sum((x[2] for x in self.msgs), Comp.empty())
        self.reg_out_aux = all_auxs.copy()
        self.bnet_out = BayesNet()
        self.reg_out_map = {}
        clen = 3 if future else 2
        # Time order used for causal inputs: index 1 first, then 0, then 2
        # (meaning of the series indices is defined by CompArray.series_sym).
        ttr = [1, 0, 2]
        for x in self.bnet.allcomp():
            if future:
                t = CompArray.series_sym(x)
            else:
                t = CompArray.series_sym(x, suff = None)
            self.reg_out_map[x.get_name()] = t
        for (a, b) in self.bnet.edges():
            self.bnet_out += (self.reg_out_map[a.get_name()], self.reg_out_map[b.get_name()])
        for i, enc in enumerate(self.enclist):
            # Use a separate name: reusing all_auxs here would make the
            # Csiszar sum below only see the last encoder's auxiliaries.
            enc_auxs = self.enc_id_get_aux(i)
            rv_out = enc.rv_out
            rv_in_causal = enc.rv_in_causal
            rv_out_map = self.reg_out_map[rv_out.get_name()]
            for x in rv_out_map:
                self.bnet_out += (enc_auxs, x)
            for pa in self.bnet.get_parents(rv_out):
                pa_map = self.reg_out_map[pa.get_name()]
                if rv_in_causal.ispresent(pa):
                    for t1 in range(clen):
                        for t2 in range(t1 + 1, clen):
                            self.bnet_out += (pa_map[ttr[t1]], rv_out_map[ttr[t2]])
                else:
                    for t1 in range(clen):
                        for t2 in range(clen):
                            if t1 != t2:
                                self.bnet_out += (pa_map[t1], rv_out_map[t2])
        self.reg_out = self.bnet_out.get_region() & self.reg
        if future:
            self.reg_out &= csiszar_sum(*(list(self.reg_out_map.values()) + list(all_auxs)))
        self.reg_out_aux += self.bnet_out.allcomp() - self.bnet.allcomp()

    def get_outer(self, future = True, convexify = False, skip_simplify = False):
        """Return an outer bound on the rates, with auxiliaries eliminated.

        convexify is accepted for interface compatibility and is not used.
        """
        self.create_reg_out(future = future)
        r = self.reg_out.copy()
        for rt in self.get_rates():
            r &= rt >= 0
        for dec in self.declist:
            a = dec.rv_in
            ap = self.reg_out_map[a.get_name()][1]
            for rt in dec.msgs:
                for _, rtbs, b, _ in self.msgs:
                    if rtbs.ispresent(rt):
                        r &= rt <= Expr.Ic(b, a, ap)
        r = r.exists(self.reg_out_aux)
        if not skip_simplify:
            return r.simplified()
        return r

    def get_decreq(self):
        """Return the decoding requirements as tuples
        (message index, decoder input, indices decoded after it, indices of
        msgints messages before it or None)."""
        r = []
        for dec in self.declist:
            a = dec.rv_in
            rts = dec.msgs
            reqs = []
            reqints = None if dec.msgints is None else []
            for i, (_, rtbs, _, _) in enumerate(self.msgs):
                if any(rtbs.ispresent(rt) for rt in rts):
                    reqs.append(i)
                if dec.msgints is not None and any(rtbs.ispresent(rt) for rt in dec.msgints):
                    reqints.append(i)
            for j in range(len(reqs) - 1, -1, -1):
                r.append((reqs[j], a, reqs[j+1:],
                          None if reqints is None else [k for k in reqints if k < reqs[j]]))
        return r

    def get_ratereq(self, known, reqprevs, i, tlist):
        """Return the rate constraints for decoding message i given known,
        the already decoded messages reqprevs and the jointly decoded tlist."""
        verbose = PsiOpts.settings.get("verbose_commmodel", False)
        r = Region.universe()
        lmask = (1 << i) + sum(1 << t for t in tlist)
        for tmask in range(1 << len(tlist)):
            mask = (1 << i) + sum(1 << tlist[j] for j in range(len(tlist)) if tmask & (1 << j))
            expr = Expr.zero()
            for j in range(i + 1):
                if mask & (1 << j):
                    expr += self.maux_rt[j]
                    if self.maux_rtsub[j] is not None:
                        expr += self.maux_rtsub[j]
                    expr -= mi_rect_max(self.msgs[j][2], [known] +
                        [(self.msgs[j2][2], self.maux_rtsub[j2]) if self.maux_rtsub[j2] is not None
                         else self.msgs[j2][2] for j2 in reqprevs] +
                        [(self.msgs[j2][2], self.maux_rtsub[j2]) if self.maux_rtsub[j2] is not None
                         else self.msgs[j2][2] for j2 in range(j) if lmask & (1 << j2)])
                    expr += mi_rect_max([self.msgs[j][0]] +
                        [(self.msgs[j2][2], self.maux_rtsub[j2]) if self.maux_rtsub[j2] is not None
                         else self.msgs[j2][2] for j2 in range(j)],
                        self.msgs[j][2])
            expr.simplify_quick()
            r &= (expr <= 0)
        r_st = None
        if verbose:
            print("========== inner bound ==========")
            print("known:" + str(known) + " decoded:" + str(sum(self.msgs[t][2] for t in reqprevs))
                  + " nonunique:" + str(sum(self.msgs[t][2] for t in tlist)) + " target:" + str(self.msgs[i][2]))
            r_st = str(r)
            print(r_st)
        r = r.flattened(minmax_elim = True)
        if verbose:
            r_st2 = str(r)
            if r_st2 != r_st:
                print("expanded:")
                print(r_st2)
            print("")
        return r

    def get_ratereq_old(self, known, i, tlist):
        """Older variant of get_ratereq (kept for compatibility)."""
        r = Region.universe()
        lmask = (1 << i) + sum(1 << t for t in tlist)
        submask = sum(1 << j for j in range(i) if self.maux_rtsub[j] is not None)
        for tmask in range(1 << len(tlist)):
            mask = (1 << i) + sum(1 << tlist[j] for j in range(len(tlist)) if tmask & (1 << j))
            for cp in itertools.product(*[igen.subset_mask(submask & ((1 << j) - 1))
                                          for j in range(i + 1) if mask & (1 << j)]):
                expr = Expr.zero()
                caux = Comp.empty()
                ji = 0
                for j in range(i + 1):
                    if mask & (1 << j):
                        expr += self.maux_rt[j] - Expr.I(self.msgs[j][2], known + caux)
                        if self.maux_rtsub[j] is not None:
                            expr += self.maux_rtsub[j]
                        cauxmiss = self.msgs[j][0].copy()
                        for j2 in range(j):
                            if submask & (1 << j2):
                                if cp[ji] & (1 << j2):
                                    expr -= self.maux_rtsub[j2]
                                    cauxmiss += self.msgs[j2][2]
                            else:
                                cauxmiss += self.msgs[j2][2]
                        if not cauxmiss.isempty():
                            expr += Expr.I(self.msgs[j][2], cauxmiss)
                        ji += 1
                    if lmask & (1 << j):
                        caux += self.msgs[j][2]
                expr.simplify_quick()
                r &= (expr <= 0)
        return r

    def get_rates(self):
        """Return the sum of all distinct rate terms of the messages."""
        r = Expr.zero()
        for _, rts, _, _ in self.msgs:
            for rt in rts:
                if not r.ispresent(rt):
                    r += rt
        return r

    def get_aux(self):
        """Return the union of the auxiliaries of all messages."""
        r = Comp.empty()
        for _, _, a, _ in self.msgs:
            r += a
        return r

    def get_inner_iter(self, subcodebook = True, skip_simplify = False):
        """Inner bound for the current message order in self.msgs.

        subcodebook : True to give every message but the last a sub-codebook
            rate; a Comp to do so only for messages whose auxiliary is in it;
            False for none. Sub-codebook rates are eliminated at the end.
        """
        self.maux_rt = [rt[0] for _, rt, _, _ in self.msgs]
        self.maux_rtsub = [None] * len(self.msgs)
        if subcodebook is True:
            self.maux_rtsub = [Expr.real("#RS_" + x.get_name()) for _, rt, x, _ in self.msgs]
            self.maux_rtsub[-1] = None
        elif isinstance(subcodebook, Comp):
            self.maux_rtsub = [Expr.real("#RS_" + x.get_name())
                               if subcodebook.ispresent(x) else None for _, rt, x, _ in self.msgs]
            self.maux_rtsub[-1] = None
            subcodebook = True
        reqs = self.get_decreq()
        r = self.reg_in.copy()
        for rt in self.get_rates():
            r &= rt >= 0
        if subcodebook:
            for rt in self.maux_rtsub:
                if rt is not None:
                    r &= rt >= 0
        for reqi, known, reqprevs, reqints in reqs:
            if reqints is None:
                # Union over every choice of earlier messages decoded jointly.
                r2 = Region.empty()
                for mask in range(1 << reqi):
                    tlist = [k for k in range(reqi) if mask & (1 << k)]
                    r2 |= self.get_ratereq(known, reqprevs, reqi, tlist)
            else:
                r2 = self.get_ratereq(known, reqprevs, reqi, reqints)
            if not skip_simplify:
                r2 = r2.simplified()
            r &= r2
        if subcodebook and any(x is not None for x in self.maux_rtsub):
            r.eliminate(sum(x for x in self.maux_rtsub if x is not None))
        return r

    def rate_splitting(self, r):
        """Apply rate splitting to the region r and return the result."""
        if self.msgs is None:
            self.create_reg_in()
        rttmp = [Expr.real("#RT_" + x.get_name()) for _, rt, x, _ in self.msgs]
        for rt1, rt2 in zip((rt for _, rt, x, _ in self.msgs), rttmp):
            r.substitute(rt1, rt2)
        # decflag[i]: bitmask of the decoders that require message i.
        decflag = [0] * len(self.msgs)
        for i, (_, rtbs, _, _) in enumerate(self.msgs):
            for j, dec in enumerate(self.declist):
                if any(rtbs.ispresent(rt) for rt in dec.msgs):
                    decflag[i] |= 1 << j
        rtauxs = []
        n = len(self.msgs)
        rn = Region.universe()
        bexps = [Expr.zero() for _ in range(n)]
        for i, (_, rtbs, _, _) in enumerate(self.msgs):
            vis = []
            cexp = Expr.zero()
            for mask in range(1, 1 << n):
                if any(mask | v == mask for v in vis):
                    continue
                cdecflag = 0
                for i2 in range(n):
                    if mask & (1 << i2):
                        cdecflag |= decflag[i2]
                if cdecflag | decflag[i] != cdecflag:
                    continue
                vis.append(mask)
                crtaux = Expr.real("#RTA_" + str(i) + "_" + str(mask))
                rn &= crtaux >= 0
                for i2 in range(n):
                    if mask & (1 << i2):
                        bexps[i2] += crtaux
                cexp += crtaux
                rtauxs.append(crtaux)
            rn &= rtbs <= cexp
        for i in range(n):
            rn &= bexps[i] <= rttmp[i]
        r &= rn
        for rt in self.get_rates():
            r &= rt >= 0
        r.eliminate(sum(rttmp + rtauxs, Expr.zero()))
        return r

    def get_inner(self, subcodebook = True, rate_split = False, shuffle = False,
                  convexify = False, convexify_diag = False,
                  skip_simplify = False, skip_simplify_iter = False):
        """Return an inner bound on the rates, with auxiliaries eliminated.

        shuffle : Take the union over all message orders that keep the
            network acyclic.
        convexify : True (time-sharing variable "Q_i") or a Comp to use as
            time-sharing variable.
        convexify_diag : Apply convexified_diag over the rates.
        rate_split : Apply rate_splitting to the result.
        """
        self.create_reg_in()
        aux = self.get_aux()
        if shuffle:
            r = Region.empty()
            tmaux = self.msgs
            try:
                for cmaux in itertools.permutations(tmaux):
                    tbnet = self.bnet.copy()
                    for i in range(len(cmaux) - 1):
                        tbnet += (cmaux[i][3], cmaux[i+1][3])
                    if tbnet.iscyclic():
                        continue
                    self.msgs = list(cmaux)
                    r |= self.get_inner_iter(subcodebook = subcodebook, skip_simplify = skip_simplify_iter)
            finally:
                # Restore the original message order even if an iteration fails.
                self.msgs = tmaux
        else:
            r = self.get_inner_iter(subcodebook = subcodebook, skip_simplify = skip_simplify_iter)
        if convexify_diag:
            r = r.convexified_diag(self.get_rates(), skip_simplify = skip_simplify)
        if convexify is not False:
            if convexify is True:
                convexify = Comp.index("Q_i")
            r.condition(convexify)
            encout = self.get_encout()
            r &= markov(r.allcomprv() - encout - convexify, encout, convexify)
            aux += convexify
        if rate_split:
            r = self.rate_splitting(r)
        if not skip_simplify:
            r = r.exists(aux)
            r.simplify_union()
            r.remove_missing_aux()
            return r.simplified_quick()
        return r.exists(aux)
# Generators
class igen:
    """Generator helpers for enumerating subsets, partitions and information terms.

    The functions take no ``self`` and are called through the class, e.g.
    ``igen.subset_mask(mask)``.
    """
    def subset_mask(mask):
        """Yield every submask of the bitmask mask in increasing order, from 0 to mask.

        Assumes mask >= 0 (NOTE(review): negative masks not considered).
        """
        x = 0
        while True:
            yield x
            if x >= mask:
                break
            # Next larger submask: add 1 with the bits outside mask forced to 1
            # (so the carry skips them), then keep only the bits inside mask.
            x = ((x | ~mask) + 1) & mask
    def partition_mask(mask, n, max_mask = None):
        """Yield each partition of the bits of mask into n nonzero disjoint masks.

        Each partition is yielded once, as a tuple in increasing order; every
        element must be <= max_mask when max_mask is given.
        """
        if n == 1:
            if mask != 0 and not (max_mask is not None and mask > max_mask):
                yield (mask,)
            return
        for xmask in igen.subset_mask(mask):
            if xmask == 0:
                continue
            # subset_mask is increasing, so no later candidate can fit either.
            if max_mask is not None and xmask > max_mask:
                break
            # xmask is the largest part; the other parts must be below it.
            for t in igen.partition_mask(mask & ~xmask, n - 1, xmask):
                yield t + (xmask,)
    def partition(x, n):
        """Yield tuples of n sums, one per group, over the partitions of x into n nonempty groups."""
        m = len(x)
        for mask in igen.partition_mask((1 << m) - 1, n):
            yield tuple(sum(t for i, t in enumerate(x) if xmask & (1 << i)) for xmask in mask)
    def subset(x, minsize = 0, maxsize = 100000, size = None, reverse = False, coeffmode = 0, replacement = False, zero_elem = 0):
        """Iterate sum of subset.
        Parameters:
            minsize, maxsize : Range of subset sizes (overridden by size).
            size : Exact subset size.
            reverse : Iterate from larger to smaller sizes.
            coeffmode : Set to 0 to only allow positive terms.
                Set to 1 to allow positive/negative terms, but not all negative.
                Set to 2 to allow positive/negative terms.
                Set to -1 to allow positive/negative terms, but not all positive/negative.
            replacement : Allow an element to be used more than once.
            zero_elem : Start value of sums of two or more terms.
        The empty subset yields the zero of the element type (Comp, Expr or
        IBaseArray), or 0 otherwise.
        """
        #if isinstance(x, types.GeneratorType):
        if not hasattr(x, '__len__'):
            x = list(x)
        if size is not None:
            minsize = size
            maxsize = size
        iterfcn = None
        if replacement:
            iterfcn = itertools.combinations_with_replacement
        else:
            iterfcn = itertools.combinations
        n = len(x)
        cr = None
        if reverse:
            maxsize = min(maxsize, n)
            cr = range(maxsize, minsize - 1, -1)
        else:
            cr = range(minsize, maxsize + 1)
        for s in cr:
            # NOTE(review): this also stops when replacement=True, so sizes
            # larger than len(x) are never produced in that mode.
            if s > n:
                return
            if s == 0:
                if len(x) > 0:
                    if isinstance(x[0], Comp):
                        yield Comp.empty()
                    elif isinstance(x[0], Expr):
                        yield Expr.zero()
                    elif isinstance(x[0], IBaseArray):
                        yield type(x[0]).zeros(shape = x[0].shape)
                    else:
                        yield 0
                else:
                    if isinstance(x, Comp):
                        yield Comp.empty()
                    elif isinstance(x, Expr):
                        yield Expr.zero()
                    elif isinstance(x, IBaseArray):
                        yield type(x).zeros(shape = x.shape)
                    else:
                        yield 0
            elif s == 1:
                # NOTE(review): for coeffmode == -1 single (all-positive)
                # terms are still yielded, unlike the s >= 2 case.
                if coeffmode == 2:
                    for y in x:
                        yield y
                        yield -y
                else:
                    for y in x:
                        yield y
            else:
                if coeffmode == 0:
                    for comb in iterfcn(x, s):
                        yield sum(comb, zero_elem)
                else:
                    for comb in iterfcn(x, s):
                        # Bit i of sflag set means comb[i] is negated; sflag = 0
                        # is all positive and (1 << s) - 1 is all negative.
                        for sflag in range(int(coeffmode == -1),
                                (1 << s) - int(coeffmode == 1 or coeffmode == -1)):
                            yield sum((comb[i] * (-1 if sflag & (1 << i) else 1)
                                for i in range(s)), zero_elem)
    def pm(x):
        """Plus or minus.
        Yield a and then -a for each element a of x.
        """
        for a in x:
            yield a
            yield -a
    def sI(x, maxsize = 100000, cond = True, ent = True, pm = False):
        """Iterate mutual information of variables in x.
        For each nonempty subset (of at most maxsize variables of x), split
        the subset into disjoint groups and yield the corresponding entropy /
        mutual information terms.
        Parameters:
            cond : Whether to include conditional mutual information.
            ent : Whether to include entropy.
            pm : Whether to include both positive and negative.
        """
        n = len(x)
        def xmask(mask):
            # Sum of the variables x[i] whose bit i is set in mask.
            return sum((x[i] for i in range(n) if mask & (1 << i)), Comp.empty())
        # Each subset is represented by its bitmask (sum of distinct powers of 2).
        for mask in igen.subset([1 << i for i in range(n)], minsize = 1, maxsize = maxsize):
            bmask = mask
            if not cond:
                if ent:
                    yield Expr.H(xmask(bmask))
                    if pm:
                        yield -Expr.H(xmask(bmask))
            # bmask runs over the submasks of mask in decreasing order.
            while True:
                if not cond:
                    if bmask <= 0:
                        break
                    amask = mask - bmask
                    # Only amask <= bmask, so each unordered pair is used once.
                    if amask > bmask:
                        break
                    if amask > 0:
                        yield Expr.I(xmask(amask), xmask(bmask))
                        if pm:
                            yield -Expr.I(xmask(amask), xmask(bmask))
                else:
                    # bmask is the conditioning set; split the rest into a0, a1.
                    a0mask = mask - bmask
                    if ent:
                        if a0mask > 0:
                            yield Expr.Hc(xmask(a0mask), xmask(bmask))
                            if pm:
                                yield -Expr.Hc(xmask(a0mask), xmask(bmask))
                    while a0mask > 0:
                        a1mask = mask - bmask - a0mask
                        if a1mask > a0mask:
                            break
                        if a1mask > 0:
                            yield Expr.Ic(xmask(a1mask), xmask(a0mask), xmask(bmask))
                            if pm:
                                yield -Expr.Ic(xmask(a1mask), xmask(a0mask), xmask(bmask))
                        a0mask = (a0mask - 1) & (mask - bmask)
                if bmask <= 0:
                    break
                bmask = (bmask - 1) & mask
    def test(x, fcn, sgn = 0, yield_set = False):
        """Test fcn for values in generator x.
        Set sgn = 1 if fcn is increasing.
        Set sgn = -1 if fcn is decreasing.
        When sgn != 0, elements contained in the MonotoneSet of known results
        (ub / lb) are skipped without calling fcn. Yields each element a with
        fcn(a) true; with yield_set and sgn != 0, yields (a, set) instead,
        where set is the MonotoneSet a was added to.
        """
        ub = MonotoneSet(sgn = 1)
        lb = MonotoneSet(sgn = -1)
        for a in x:
            if sgn != 0:
                if a in ub:
                    continue
                if a in lb:
                    continue
            if fcn(a):
                if not yield_set:
                    yield a
                if sgn > 0:
                    ub.add(a)
                    if yield_set:
                        yield (a, ub)
                elif sgn < 0:
                    lb.add(a)
                    if yield_set:
                        yield (a, lb)
            else:
                if sgn > 0:
                    lb.add(a)
                elif sgn < 0:
                    ub.add(a)
def remove_comments(s):
    """Strip '#' and '%' comments from the multi-line string s.

    Lines whose first non-blank character is a comment character are dropped
    entirely; other lines are cut at their first comment character. The
    remaining pieces are joined with newlines, except that empty pieces
    occurring before the first non-empty piece are discarded.
    """
    markers = ("#", "%")
    kept = []
    for line in s.split("\n"):
        if line.lstrip().startswith(markers):
            continue
        cut = min((pos for pos in (line.find(m) for m in markers) if pos >= 0), default = len(line))
        piece = line[:cut]
        if not kept and piece == "":
            continue
        kept.append(piece)
    return "\n".join(kept)
def parse_sanitize(s, grammar = False):
    r"""Normalize LaTeX-style markup so that the Lark parser can tokenize it.

    The same transformation is applied to the parser grammar (grammar=True),
    so that LaTeX commands used as grammar literals match the transformed
    input.

    Parameters:
        s : Input string.
        grammar : True when s is the grammar itself. Comments are then not
            stripped, and inserted keywords are not padded with spaces.

    Returns:
        The sanitized string:
        - presentational commands (spacing, sizing, \left, \right, \big, ...)
          become spaces;
        - \{ and \} become BBRACE_L / BBRACE_R;
        - arrows and quantifiers become plain keywords;
        - every remaining backslash becomes "BACKSLASH_".
    """
    if not grammar:
        s = remove_comments(s)
    c = "" if grammar else " "
    # Applied in order with str.replace. A command that is a prefix of another
    # (\big of \bigg, \left of \leftarrow, ...) must come after the longer
    # one, otherwise the longer command is only partially removed (e.g.
    # \bigg would leave a stray "g").
    replacements = (
        # LaTeX line break first, so its second backslash is not mistaken
        # for the start of \{, \, etc.
        (r"\\", " "),
        (r"\{", c + "BBRACE_L" + c),
        (r"\}", c + "BBRACE_R" + c),
        (r"\,", " "),
        (r"\!", " "),
        (r"\;", " "),
        (r"\[", " "),
        (r"\]", " "),
        (r"\quad", " "),
        (r"\qquad", " "),
        (r"\leftrightarrow", c + "lrarrow" + c),
        (r"\leftarrow", c + "larrow" + c),
        (r"\rightarrow", c + "rarrow" + c),
        (r"\left", " "),
        (r"\right", " "),
        (r"\biggl", " "),
        (r"\biggr", " "),
        (r"\Biggl", " "),
        (r"\Biggr", " "),
        (r"\bigg", " "),
        (r"\Bigg", " "),
        (r"\bigl", " "),
        (r"\bigr", " "),
        (r"\Bigl", " "),
        (r"\Bigr", " "),
        (r"\big", " "),
        (r"\Big", " "),
        (r"\small", " "),
        (r"\tiny", " "),
        (r"\scriptsize", " "),
        (r"\footnotesize", " "),
        (r"\normalsize", " "),
        (r"\large", " "),
        (r"\Large", " "),
        (r"\LARGE", " "),
        (r"\huge", " "),
        (r"\Huge", " "),
        (r"\displaystyle", " "),
        (r"\nonumber", " "),
        (r"\notag", " "),
        (r"\exists", c + "exists" + c),
        (r"\forall", c + "forall" + c),
    )
    for old, new in replacements:
        s = s.replace(old, new)
    for i in range(5):
        s = s.replace(r"," + " " * i + r"exists", "comma_exists")
        s = s.replace(r"," + " " * i + r"forall", "comma_forall")
        s = s.replace(r"{\perp" + " " * i + r"\perp}", c + r"\perp" + c)
    # Any backslash left starts a LaTeX command used as a keyword (e.g. \le,
    # \wedge); turn it into an identifier-safe prefix.
    s = s.replace("\\", c + "BACKSLASH_")
    return s
@lark_v_args(inline = True)
class RegionTree(lark_Transformer):
    """Lark transformer turning a parsed region/expression string into PSITIP objects.

    ``grammar`` holds the Lark grammar, sanitized with parse_sanitize when
    the class is created. Named variables are created on first use and kept
    in ``varmap``, so repeated names refer to the same object.
    """
    grammar = r"""
?start: region_or_expr
| NAME "=" region_or_expr -> set_var
| "check" region_or_expr -> check
| "latex" region_or_expr -> latex
| "assume" region -> assume
| "clear" "assume" -> clear_assume
| "style" NAME -> style
?region_or_expr: region
| expr
| "simplify" region_or_expr -> simplified
?region: region_aux
| region (">>" | "implies" | "\Rightarrow" | "\to" | "\rightarrow") region_aux -> rshift
?region_aux: region_union
| "exists" comp ("." | ":" | "st") region_aux -> exists_r
| "forall" comp ("." | ":" | "st") region_aux -> forall_r
| region_aux ", exists" comp -> exists
| region_aux ", forall" comp -> forall
?region_union: region_inter
| region_union "|" region_inter -> or_
| region_union "or" region_inter -> or_
| region_union "OR" region_inter -> or_
| region_union "\vee" region_inter -> or_
?region_inter: region_atom
| region_inter "&" region_atom -> and_
| region_inter_name "and" region_atom -> and_
| region_inter_name "AND" region_atom -> and_
| region_inter_name "\wedge" region_atom -> and_
| region_inter "," region_atom -> and_
?region_inter_name: region_inter
| NAME -> region_var
?region_atom: "(" region ")"
| "{" region "}"
| "\{" region "\}"
| "\begin" "{" dis_latex_block "}" ["{" NAME "}"] region "\end" "{" dis_latex_block "}" -> taken2
| expr ( ["&"] rel expr)+ -> rels
| "~" region_atom -> not_
| "not" region_atom -> not_
| "NOT" region_atom -> not_
| "\lnot" region_atom -> not_
| "markov" "(" comp_closer ("," comp_closer)+ ")" -> markov
| "indep" "(" comp_closer ("," comp_closer)+ ")" -> indep
| region_atom "." "exists" comp -> exists
| region_atom "." "forall" comp -> forall
| region_atom "." "simplified" "(" ")" -> simplified
| region_atom "." "latex" "(" ")" -> latex
| (comp_single | comp_comma_b) (("markov" | "->" | "<->" | "\to" | "\rightarrow" | "\leftrightarrow") (comp_single | comp_comma_b))+ -> markov
| (comp_single | comp_comma_b) (("indep" | "\perp" | ("\perp" "\perp") | ".") (comp_single | comp_comma_b))+ -> indep
| "assumption" -> assumption
| "@" NAME -> region_var
?region_atom_name: region_atom
| NAME -> region_var
?rel: ("==" | "=") -> rel_eq
| ("!=" | "\neq") -> rel_ne
| "<" -> rel_lt
| ">" -> rel_gt
| ("<=" | "\le" | "\leq") -> rel_le
| (">=" | "\ge" | "\geq") -> rel_ge
?expr: expr_product
| expr "+" expr_product -> add
| expr "-" expr_product -> sub
?expr_product: expr_atom
| expr_number expr_atom -> mul
| expr_product ("*" | "\cdot" | "\times") expr_atom -> mul
| expr_product "/" expr_atom -> truediv
?expr_atom: expr_number
| "H" "(" comp ")" -> ent
| "H" "(" comp "|" comp ")" -> entc
| "I" "(" comp_and ")" -> mi
| "I" "(" comp_and "|" comp ")" -> mic
| "-" expr_atom -> neg
| NAME -> expr_var
| NAME "{" (NAME | NUMBER) "}" -> expr_var
?expr_number: cnumber
| "\frac" "{" expr "}" "{" expr "}" -> truediv
| "(" expr ")"
| "{" expr "}"
| "\{" expr "\}"
?cnumber: NUMBER -> number
?comp_and: comp
| comp_and "&" comp -> and_
| comp_and "\wedge" comp -> and_
| comp_and ";" comp -> and_
| comp_and ":" comp -> and_
?comp: comp_closer
| comp "," comp_closer -> add
?comp_closer: comp_atom
| comp_closer "+" comp_atom -> add
| comp_closer comp_atom -> add
?comp_atom: comp_single
| "(" comp ")"
| "{" comp "}"
| "\{" comp "\}"
?comp_single: NAME -> comp_var
| NAME "{" (NAME | NUMBER) "}" -> comp_var
| NAME "^" "{" (NAME | NUMBER) "}" -> comp_var
| NAME "^" (NAME | NUMBER) -> comp_var
?comp_comma: "(" comp_comma ")"
| (comp_single | comp_comma) "," comp -> add
?comp_comma_b: "(" comp_comma ")"
?dis_latex_block: ("array" | "align" | ("align" "*") | "equation" | ("equation" "*") | "eqnarray" | ("eqnarray" "*") | "split" | ("split" "*") | "multline" | ("multline" "*"))
%import common.CNAME -> NAME
%import common.NUMBER
%import common.WS
%ignore WS
%ignore "$"
%ignore "$$"
"""
    grammar = parse_sanitize(grammar, grammar = True)

    # Operator callbacks named after grammar aliases (add, sub, and_, ...).
    from operator import eq, ne, lt, le, gt, ge, and_, or_, not_, add, sub, mul, truediv, neg, rshift
    number = float

    def __init__(self):
        self.varmap = {}

    def clear(self):
        """Forget all named variables."""
        self.varmap = {}

    def take1(self, *args):
        return args[1]

    def taken1(self, *args):
        return args[-1]

    def taken2(self, *args):
        return args[-2]

    def set_var(self, name, value):
        """Bind name to value and return value."""
        self.varmap[name] = value
        return value

    def type_var(self, *args, cur_type = "comp"):
        """Return the variable named by args, creating it on first use.

        A two-part name whose first part ends with "_" becomes "A_{b}";
        otherwise the parts are concatenated. New variables are a real
        variable ("expr"), the universe region ("region") or a random
        variable (anything else).
        """
        if len(args) == 2 and args[0].endswith("_"):
            key = str(args[0]) + "{" + str(args[1]) + "}"
        else:
            key = "".join(map(str, args))
        if key in self.varmap:
            return self.varmap[key]
        if cur_type == "expr":
            created = Expr.real(key)
        elif cur_type == "region":
            created = Region.universe()
        else:
            created = Comp.rv(key)
        self.varmap[key] = created
        return created

    def ent(self, x):
        return Expr.H(x)

    def entc(self, x, y):
        return Expr.Hc(x, y)

    def mi(self, x):
        return I(x)

    def mic(self, x, y):
        return I(x | y)

    def expr_var(self, *args):
        return self.type_var(*args, cur_type = "expr")

    def comp_var(self, *args):
        return self.type_var(*args, cur_type = "comp")

    def region_var(self, *args):
        return self.type_var(*args, cur_type = "region")

    def simplified(self, x):
        return x.simplified_truth()

    def check(self, x):
        return x.check()

    def latex(self, x):
        return x.latex()

    def assume(self, x):
        """Add x to the global assumption and return the current assumption."""
        x.assume()
        return self.assumption()

    def assumption(self):
        """Return a copy of the global assumption (the universe if none)."""
        truth = PsiOpts.settings["truth"]
        return Region.universe() if truth is None else truth.copy()

    def clear_assume(self):
        PsiOpts.setting(truth = None)

    def add_list(self, x, y):
        left = [x] if isinstance(x, Comp) else x
        right = [y] if isinstance(y, Comp) else y
        return left + right

    def markov(self, *x):
        return markov(*x)

    def indep(self, *x):
        return indep(*x)

    def exists(self, x, y):
        return x.exists(y)

    def forall(self, x, y):
        return x.forall(y)

    def exists_r(self, x, y):
        return y.exists(x)

    def forall_r(self, x, y):
        return y.forall(x)

    def rels(self, *x):
        """Intersect the chained comparisons x = (e0, rel0, e1, rel1, e2, ...)."""
        compare = {
            "eq": lambda p, q: p == q,
            "ne": lambda p, q: p != q,
            "lt": lambda p, q: p < q,
            "gt": lambda p, q: p > q,
            "le": lambda p, q: p <= q,
            "ge": lambda p, q: p >= q,
        }
        combined = None
        for pos in range(0, len(x) - 2, 2):
            lhs, op, rhs = x[pos:pos + 3]
            cur = compare[op](lhs, rhs) if op in compare else None
            if combined is None:
                combined = cur
            else:
                combined &= cur
        return combined

    def rel_eq(self):
        return "eq"

    def rel_ne(self):
        return "ne"

    def rel_lt(self):
        return "lt"

    def rel_gt(self):
        return "gt"

    def rel_le(self):
        return "le"

    def rel_ge(self):
        return "ge"

    def style(self, s):
        """Set the global string style; return an error message if s is unknown."""
        chosen = {"std": "std", "standard": "std", "code": "code",
                  "psitip": "code", "latex": "latex"}.get(s)
        if chosen is None:
            return "Unrecognized style: " + s + ". Options are: std, code, latex"
        PsiOpts.setting(str_style = chosen)
class RegionParser:
    """Parse strings into regions/expressions with Lark and RegionTree.

    Raises RuntimeError on construction if Lark is not installed.
    """
    default_parser = None

    def __init__(self):
        if lark is None:
            raise RuntimeError("Requires Lark. Please install it first.")
        self.tree = RegionTree()
        self.parser = lark.Lark(RegionTree.grammar, parser = "lalr", transformer = self.tree)

    def parse(self, s):
        """Parse s. If a multi-line input fails, retry with lines joined by commas."""
        text = parse_sanitize(s)
        if "\n" not in text:
            return self.parser.parse(text)
        try:
            return self.parser.parse(text)
        except lark.exceptions.LarkError:
            stripped = (line.strip() for line in text.split("\n"))
            text = ",\n".join(line for line in stripped if line)
        return self.parser.parse(text)

    def clear(self):
        """Forget all named variables."""
        self.tree.clear()

    @staticmethod
    def parse_default(s, clear = True):
        """Parse s with a shared, lazily created parser (cleared first unless clear is False)."""
        parser = RegionParser.default_parser
        if parser is None:
            parser = RegionParser.default_parser = RegionParser()
        if clear:
            parser.clear()
        return parser.parse(s)
# Shortcuts
def alland(a):
    """Intersection of the elements of the iterable a (using operator &).

    Elements are combined left to right. The elements of a are not modified
    in place. Returns None if a is empty.
    """
    r = None
    for x in a:
        if r is None:
            r = x
        else:
            # Binary & rather than &=: r may alias the first element of a,
            # and an in-place &= would mutate the caller's object.
            r = r & x
    return r
def anyor(a):
    """Union of the elements of the iterable a (using operator |).

    Elements are combined left to right. The elements of a are not modified
    in place. Returns None if a is empty.
    """
    r = None
    for x in a:
        if r is None:
            r = x
        else:
            # Binary | rather than |=: r may alias the first element of a,
            # and an in-place |= would mutate the caller's object.
            r = r | x
    return r
def rv(*args, latex = None, index = False, split = True, alg = None, color = None, **kwargs):
    """Random variable(s).

    Parameters:
        *args : Names (split at commas when split is True) or objects
            accepted by iutil.ensure_comp. Items that convert to None are
            skipped.
        latex : LaTeX names, one per variable (a str is split at commas when
            split is True).
        index : If True, add the marker ("index_shift", 0).
        split : Whether to split names / latex / color strings at commas.
        alg : If given, passed to set_algtype.
        color : Colors, cycled over the variables (a str is split at commas
            when split is True); None or "" entries are skipped.
        **kwargs : Added as markers (key, value).

    Returns:
        A Comp (empty if no arguments are given).
    """
    if len(args) == 0:
        return Comp.empty()
    if split:
        args = [x for b in args for x in iutil.split_comma(b)]
    if isinstance(latex, str):
        latex = iutil.split_comma(latex) if split else [latex]
    if latex is not None:
        args = [x + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + y for x, y in zip(args, latex)]
    if isinstance(color, str):
        color = iutil.split_comma(color) if split else [color]
    r = Comp.empty()
    for i, a in enumerate(args):
        t = iutil.ensure_comp(a)
        # Check for None before touching t: add_markers on None would raise.
        if t is None:
            continue
        if color is not None and len(color) > 0:
            col = color[i % len(color)]
            if col is not None and col != "":
                t.add_markers([("color", col)])
        r += t
    if index:
        r.add_markers([("index_shift", 0)])
    for key, value in kwargs.items():
        r.add_markers([(key, value)])
    if alg is not None:
        r.set_algtype(alg)
    return r
def rv_seq(name, st, en = None, alg = None):
    """Sequence of random variables, built with Comp.array(name, st, en).

    If alg is given, it is applied with set_algtype.
    """
    seq = Comp.array(name, st, en)
    if alg is None:
        return seq
    seq.set_algtype(alg)
    return seq
def rv_array(name, st, en = None, alg = None):
    """Array (CompArray) of the random variables of rv_seq(name, st, en, alg)."""
    return CompArray.make(*rv_seq(name, st, en, alg))
def rv_series(name, future = True, sufp = "P", suff = "F"):
    """Array of random variables, with past, current and future random variables.

    sufp / suff are the past / future suffixes; no future variable is
    created when future is False.
    """
    return CompArray.series_sym(name, sufp = sufp, suff = suff if future else None)
def real(*args, latex = None, split = True, color = None):
    """Real variable(s).

    Strings become real variables, numbers become constants, and objects
    contribute copies of their allcomprealvar_exprlist(). Other items are
    ignored. latex / color follow the same conventions as rv().

    Returns:
        The single Expr when exactly one string/number argument is given,
        otherwise an ExprArray.
    """
    if split:
        args = [x for b in args for x in iutil.split_comma(b)]
    if isinstance(latex, str):
        latex = iutil.split_comma(latex) if split else [latex]
    if latex is not None:
        args = [x + "@@" + str(PsiOpts.STR_STYLE_LATEX) + "@@" + y for x, y in zip(args, latex)]
    if isinstance(color, str):
        color = iutil.split_comma(color) if split else [color]
    single = len(args) == 1
    out = []
    for pos, item in enumerate(args):
        if isinstance(item, str):
            var = Expr.real(item)
            if color is not None and len(color) > 0:
                col = color[pos % len(color)]
                if col is not None and col != "":
                    var.terms[0][0].x[0].add_markers([("color", col)])
            out.append(var)
        elif iutil.isnumeric(item):
            out.append(Expr.const(item))
        elif isinstance(item, IBaseObj):
            out.extend(v.copy() for v in item.allcomprealvar_exprlist())
            single = False
    if single and len(out) == 1:
        return out[0]
    return ExprArray(out)
def real_array(name, st, en = None):
    """Array of real variables.

    The variable names come from the random-variable sequence
    ``rv_seq(name, st, en)``, and one real variable is created per name.

    Returns:
        ExprArray of real-valued expressions.
    """
    t = rv_seq(name, st, en)
    return ExprArray([Expr.real(a.name) for a in t.varlist])
def real_seq(name, st, en = None):
    """Sequence of real variables.

    The variable names come from the random-variable sequence
    ``rv_seq(name, st, en)``, and one real variable is created per name.

    Returns:
        ExprArray of real-valued expressions.
    """
    t = rv_seq(name, st, en)
    return ExprArray([Expr.real(a.name) for a in t.varlist])
def expr(*args):
    """Convert to an expression.

    Each argument is converted with ``iutil.ensure_expr(..., strict = False)``.
    Arguments that convert to None are skipped, and the rest are summed.
    Returns ``Expr.zero()`` when called with no arguments, and None when no
    argument converts.
    """
    if not args:
        return Expr.zero()
    total = None
    for item in args:
        converted = iutil.ensure_expr(item, strict = False)
        if converted is None:
            continue
        if total is None:
            total = converted.copy()
        else:
            total += converted
    return total
def region(*args):
    """Convert to a region.

    Arguments that are Comp/Term objects, or tuples of Comp objects, are
    accumulated into one BayesNet. All other arguments are converted with
    ``iutil.ensure_region`` and intersected. The result is the intersection
    of the Bayesian network's region and the converted regions.

    Returns ``Region.universe()`` for no arguments, and None when nothing
    converts.
    """
    if not args:
        return Region.universe()
    combined = None
    net = None
    for item in args:
        is_bn_part = isinstance(item, (Comp, Term)) or (
            isinstance(item, tuple) and all(isinstance(x, Comp) for x in item))
        if is_bn_part:
            if net is None:
                net = BayesNet()
            net += item
            continue
        converted = iutil.ensure_region(item)
        if converted is None:
            continue
        if combined is None:
            combined = converted.copy()
        else:
            combined &= converted
    if net is None:
        return combined
    if combined is None:
        return net.get_region()
    return net.get_region() & combined
def rv_empty():
    """Return the empty random variable (``Comp.empty()``)."""
    nothing = Comp.empty()
    return nothing
def zero():
    """Return the zero expression (``Expr.zero()``)."""
    nil = Expr.zero()
    return nil
def universe():
    """Universal set.

    Return the region with no constraints (``Region.universe()``).
    """
    everything = Region.universe()
    return everything
def empty():
    """Empty set.

    Return the region that is empty (``RegionOp.empty()``).
    """
    nothing = RegionOp.empty()
    return nothing
def emptyset():
    """Empty set.

    Alias of ``empty()``: returns ``RegionOp.empty()``.
    """
    nothing = RegionOp.empty()
    return nothing
def H_inner(*args, prefer_multi = False):
    """Shared implementation of H and I.

    With a single argument:
        Comp     -> Expr.H(x), or Expr.zero() if x is empty;
        ConcDist -> x.entropy();
        Term     -> an Expr with that term and coefficient 1.0;
        other    -> Expr.H(iutil.ensure_comp(x)).
    With several arguments, the term is built by ``Term.from_symbols``,
    using ``prefer_multi`` to choose how the arguments are combined.
    """
    if len(args) != 1:
        return Expr.fromterm(Term.from_symbols(args, prefer_multi=prefer_multi))
    x = args[0]
    if isinstance(x, Comp):
        return Expr.zero() if x.isempty() else Expr.H(x)
    if isinstance(x, ConcDist):
        return x.entropy()
    if isinstance(x, Term):
        return Expr([(x.copy(), 1.0)])
    return Expr.H(iutil.ensure_comp(x))
@fcn_list_to_list
def H(*args):
    """Entropy.

    Return a symbolic expression for the entropy,
    e.g. use H(X + Y | Z + W) for H(X,Y|Z,W).
    Multiple arguments are combined with prefer_multi=False.
    """
    return H_inner(*args, prefer_multi = False)
@fcn_list_to_list
def I(*args):
    """Mutual information.

    Return a symbolic expression for the mutual information,
    e.g. use I(X + Y & Z | W) for I(X,Y;Z|W).
    Multiple arguments are combined with prefer_multi=True.
    """
    return H_inner(*args, prefer_multi = True)
def I0(*args):
    """Shorthand for I(...)==0.
    e.g. use I0(X + Y & Z | W) for I(X,Y;Z|W)==0

    Delegates to I (not H), so several positional arguments such as
    I0(X, Y) are combined with prefer_multi=True, the same as I(X, Y).
    """
    return I(*args) == 0
@fcn_list_to_list
def Hc(x, z):
    """Conditional entropy.

    Hc(X, Z) is the same as H(X | Z); delegates to ``Expr.Hc``.
    """
    cond_ent = Expr.Hc(x, z)
    return cond_ent
@fcn_list_to_list
def Ic(x, y, z):
    """Conditional mutual information.

    Ic(X, Y, Z) is the same as I(X & Y | Z); delegates to ``Expr.Ic``.
    """
    cond_mi = Expr.Ic(x, y, z)
    return cond_mi
@fcn_list_to_list
def indep(*args):
    """Return Region where the arguments are independent.

    Mutual independence is expressed sequentially: each argument has zero
    mutual information with the sum of all arguments before it. If any
    argument is a Term, the constraint is conditioned on the term's
    conditioning part.
    """
    if any(isinstance(a, Term) for a in args):
        term = Term.from_symbols(args, prefer_multi=True)
        return indep(*(term.x)).conditioned(term.z)
    comps = [iutil.ensure_comp(t) for t in args]
    r = Region.universe()
    for i, cur in enumerate(comps[1:], start=1):
        r &= Expr.I(iutil.sumlist(comps[:i]), iutil.sumlist(cur)) == 0
    return r
def indep_across(*args):
    """Take several arrays, return Region where entries are independent across dimension.

    Column i is the sum of the i-th entries of every array long enough to
    have one. The columns are then required to be mutually independent.
    """
    length = max(len(a) for a in args)
    columns = []
    for i in range(length):
        columns.append(iutil.sumlist([a[i] for a in args if len(a) > i]))
    return indep(*columns)
@fcn_list_to_list
def equiv(*args):
    """Return Region where the arguments contain the same information.

    After ``iutil.type_coerce``, every argument is compared with the first:
      - random variables: each is a function of the other;
      - expressions or regions: they are required to be equal (``==``).
    Zero or one argument gives the universe region.
    """
    if any(isinstance(a, Term) for a in args):
        term = Term.from_symbols(args, prefer_multi=True)
        return equiv(*(term.x)).conditioned(term.z)
    args = iutil.type_coerce(args)
    if len(args) <= 1:
        return Region.universe()
    first = args[0]
    r = Region.universe()
    for i in range(1, len(args)):
        other = args[i]
        if isinstance(first, Comp):
            r &= (Expr.Hc(other, first) == 0) & (Expr.Hc(first, other) == 0)
        elif isinstance(first, (Expr, Region)):
            r &= first == other
    return r
@fcn_list_to_list
def markov(*args):
    """Return Region where the arguments form a Markov chain.

    For each position i >= 2, all arguments before args[i-1] are required to
    be conditionally independent of args[i] given args[i-1].
    """
    if any(isinstance(a, Term) for a in args):
        term = Term.from_symbols(args, prefer_multi=True)
        return markov(*(term.x)).conditioned(term.z)
    comps = [iutil.ensure_comp(t) for t in args]
    r = Region.universe()
    for i in range(2, len(comps)):
        past = iutil.sumlist(comps[:i - 1])
        r &= Expr.Ic(past, iutil.sumlist(comps[i]), iutil.sumlist(comps[i - 1])) == 0
    return r
def eqdist(*args):
    """Return Region where the argument lists have the same distribution.
    Only equalities of entropies are enforced.
    e.g. eqdist([X, Y], [Z, W])

    For every later list and every nonempty subset of positions (up to the
    shortest list length), the joint entropy of that subset must match the
    first list's.
    """
    m = min(len(a) for a in args)
    ref = args[0]
    r = Region.universe()
    for other in args[1:]:
        for mask in range(1, 1 << m):
            x = Comp.empty()
            y = Comp.empty()
            for j in range(m):
                if (mask >> j) & 1:
                    x += ref[j]
                    y += other[j]
            if x != y:
                r &= Expr.H(x) == Expr.H(y)
    return r
def eqdist_across(*args):
    """Take several arrays, return Region where entries have the same distribution
    across dimension. Only equalities of entropies are enforced.

    Arrays are truncated to the shortest length. Column i is the list of
    i-th entries, and all columns are required to have equal entropies.
    """
    length = min(len(a) for a in args)
    columns = []
    for i in range(length):
        columns.append([a[i] for a in args])
    return eqdist(*columns)
def exchangeable(*args):
    """Return Region where the arguments are exchangeable random variables.
    Only equalities of entropies are enforced.
    e.g. exchangeable(X, Y, Z)

    For every size k < len(args), every k-subset must have the same joint
    entropy as the first k arguments.
    """
    r = Region.universe()
    for size in range(1, len(args)):
        prefix = sum(args[:size])
        for subset in itertools.combinations(args, size):
            chosen = sum(subset)
            if chosen != prefix:
                r &= Expr.H(prefix) == Expr.H(chosen)
    return r
def iidseq(*args):
    """Return Region where the arguments form an i.i.d. sequence.
    Only equalities of entropies are enforced.
    e.g. iidseq(X, Y, Z), iidseq([X1,X2], [Y1,Y2], [Z1,Z2])

    This is the intersection of ``indep`` and ``eqdist`` on the same arguments.
    """
    independent = indep(*args)
    return independent & eqdist(*args)
def iidseq_across(*args):
    """Take several arrays, return Region where entries are i.i.d.
    across dimension. Only equalities of entropies are enforced.

    Arrays are truncated to the shortest length, and column i is the list
    of i-th entries.
    """
    length = min(len(a) for a in args)
    columns = [[a[i] for a in args] for i in range(length)]
    independent = indep(*columns)
    return independent & eqdist(*columns)
def bnet(*args):
    """Create a Bayesian network with a list of edges (each edge is a tuple),
    or obtain the Bayesian network of a region.

    Regions contribute their ``get_bayesnet()``; every other argument is
    added to the network directly.
    """
    net = BayesNet()
    for item in args:
        net += item.get_bayesnet() if isinstance(item, Region) else item
    return net
@fcn_list_to_list
def sfrl_cons(x, y, u, k = None, gap = None):
    """Strong functional representation lemma.
    Li, C. T., & El Gamal, A. (2018). Strong functional representation lemma and
    applications to coding theorems. IEEE Trans. Info. Theory, 64(11), 6967-6978.

    Without k: y is a function of (x, u), u is independent of x, and
    optionally I(x; u | y) <= gap.
    With k: k is a function of (x, u), y is a function of (u, k), u is
    independent of x, and optionally H(k) <= I(x; y) + gap.
    A non-Expr gap is converted with ``Expr.const``.
    """
    if k is None:
        r = (Expr.Hc(y, x + u) == 0) & (Expr.I(x, u) == 0)
        if gap is None:
            return r
        bound = gap if isinstance(gap, Expr) else Expr.const(gap)
        r &= Expr.Ic(x, u, y) <= bound
        return r
    r = (Expr.Hc(k, x + u) == 0) & (Expr.Hc(y, u + k) == 0) & (Expr.I(x, u) == 0)
    if gap is None:
        return r
    bound = gap if isinstance(gap, Expr) else Expr.const(gap)
    r &= Expr.H(k) <= Expr.I(x, y) + bound
    return r
@fcn_list_to_list
def sunflower(*args):
    """Sunflower dependency.

    For every argument, the remaining arguments are mutually independent
    conditioned on it.
    """
    r = Region.universe()
    for i, center in enumerate(args):
        petals = args[:i] + args[i+1:]
        r &= indep(*petals).conditioned(center)
    return r
@fcn_list_to_list
def stardep(*args):
    """Star dependency.

    There exists an auxiliary SU that is a function of every argument and
    given which the arguments are mutually independent.
    """
    hub = Comp.rv("SU").avoid(*args)
    r = Region.universe()
    for leaf in args:
        r &= H(hub | leaf) == 0
    return (r & indep(*args).conditioned(hub)).exists(hub)
def interactive_discussions(parties, messages, is_fcn=False):
    """Interactive discussions between parties.
    Calling interactive_discussions([X, Y, Z], [M1, M2, M3, M4]) corresponds to the situation
    where party 1, 2, 3 have X, Y, Z resp., party 1 broadcasts M1 (based on X), party 2 broadcasts M2 (based on Y, M1),
    party 3 broadcasts M3, and party 1 broadcasts M4.

    Message i is sent by party i % len(parties). In the Bayesian network its
    parents are that party's variable and all earlier messages. If is_fcn
    is True, ``bn.set_fcn`` is called on every message (presumably marking it
    as a deterministic function of its parents). Returns the network's
    region with the placeholder variables substituted by ``messages`` and
    ``parties``.
    """
    n = len(parties)
    m = len(messages)
    # Placeholder variables, replaced by the caller's variables at the end.
    X = rv_seq("#TMP_PARTY", n)
    M = rv_seq("#TMP_MESSAGE", m)
    bn = BayesNet()
    bn += X
    for i in range(m):
        # Edge: (speaker's variable + all previous messages) -> M[i].
        bn += (X[i % n] + M[:i], M[i])
        if is_fcn:
            bn.set_fcn(M[i])
    r = bn.get_region()
    # NOTE(review): the return value of substitute() is ignored, so it is
    # assumed to modify r in place.
    for i in range(m):
        r.substitute(M[i], messages[i])
    for i in range(n):
        r.substitute(X[i], parties[i])
    return r
@fcn_list_to_list
def cardbd(x, n):
    """Return Region where the cardinality of x is upper bounded by n.

    For n <= 1 this is simply H(x) == 0. Otherwise it is the intersection of
    two parts:
      - a universally quantified condition over an auxiliary chain V of
        functions starting from x;
      - the entropy bound H(x) <= log(n) * PsiOpts.settings["ent_coeff"].
    """
    if n <= 1:
        return H(x) == 0
    # numpy.log below is the natural log; ent_coeff presumably rescales it to
    # the configured entropy unit.
    loge = PsiOpts.settings["ent_coeff"]
    V = rv_seq("V", 0, n-1).avoid(x)
    r = Expr.H(V[n - 2]) == 0
    r2 = Region.universe()
    for i in range(0, n - 1):
        # r2: V[i] is a function of its predecessor (x when i == 0).
        r2 &= Expr.Hc(V[i], V[i - 1] if i > 0 else x) == 0
        # r (union): the last element is constant, or some step is invertible.
        r |= Expr.Hc(V[i - 1] if i > 0 else x, V[i]) == 0
    # Presumably "for all V, r2 implies r".
    r = r.implicated(r2, skip_simplify = True).forall(V)
    return r & (H(x) <= numpy.log(n) * loge)
@fcn_list_to_list
def isbin(x):
    """Return Region where x is a binary random variable.

    Equivalent to ``cardbd(x, 2)``, i.e. the cardinality of x is at most 2.
    """
    binary = cardbd(x, 2)
    return binary
@fcn_list_to_list
def isuniform(x):
    """Return Region where x is a uniformly distributed random variable.
    Zhen Zhang and Raymond W Yeung, "A non-Shannon-type conditional inequality
    of information quantities", IEEE Trans. Inf. Theory 43, 6 (1997), pp. 1982-1986.

    The region requires auxiliaries U, V such that x, U, V are pairwise
    independent and each of the three is a function of the other two.
    """
    aux_u = rv("U").avoid(x)
    aux_v = rv("V").avoid(x)
    r = (H(x | aux_u + aux_v) == 0) & (H(aux_u | x + aux_v) == 0) & (H(aux_v | x + aux_u) == 0)
    r &= indep(x, aux_u) & indep(x, aux_v) & indep(aux_u, aux_v)
    return r.exists(aux_u + aux_v)
def sfrl(gap = None):
    """Strong functional representation lemma.
    Li, C. T., & El Gamal, A. (2018). Strong functional representation lemma and
    applications to coding theorems. IEEE Trans. Info. Theory, 64(11), 6967-6978.

    Returns the statement: for all SX, SY there exists SU with
    I(SX; SU) == 0 and H(SY | SX, SU) == 0. If gap is given, it additionally
    requires I(SX; SU | SY) <= gap (a non-Expr gap is converted with
    Expr.const).
    """
    # Fresh id shared by the "disjoint" markers of SX and SY.
    disjoint_id = iutil.get_count()
    # SX, SY, SU = rv("SX", "SY", "SU")
    SX = rv("SX", latex = "S_X")
    SY = rv("SY", latex = "S_Y")
    SU = rv("SU", latex = "S_U")
    SX.add_markers([("disjoint", disjoint_id), ("nonempty", 1)])
    SY.add_markers([("disjoint", disjoint_id), ("nonempty", 1)])
    #SU.add_markers([("mustuse", 1)])
    r = ((Expr.Hc(SY, SX + SU) == 0).add_meta("pf_note", ["SFRL fcn"])
        & (Expr.I(SX, SU) == 0).add_meta("pf_note", ["SFRL indep."]))
    if gap is not None:
        if not isinstance(gap, Expr):
            gap = Expr.const(gap)
        r &= (Expr.Ic(SX, SU, SY) <= gap).add_meta("pf_note", ["SFRL gap"])
    return r.exists(SU).forall(SX + SY)
def copylem(n = 2, m = 1):
    """Copy lemma: for any A, B, there exists C such that (A, B) has the same
    distribution as (A, C), and B-A-C forms a Markov chain.
    n, m are the dimensions of A, B respectively.
    Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities,"
    IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998.
    Randall Dougherty, Chris Freiling, and Kenneth Zeger. "Non-Shannon information
    inequalities in four random variables." arXiv preprint arXiv:1104.3602 (2011).

    Only entropy equalities are enforced for "same distribution" (see
    eqdist). A and B carry a shared "disjoint" marker and separate "symm"
    groups.
    """
    # Fresh marker ids; the call order determines the ids.
    disjoint_id = iutil.get_count()
    symm_id_x = iutil.get_count()
    symm_id_y = iutil.get_count()
    X = rv_seq("A", 0, n)
    for i in range(n):
        X[i].add_markers([("disjoint", disjoint_id), ("symm", symm_id_x), ("symm_nonempty", 1)])
    Y = rv_seq("B", 0, m)
    Z = rv_seq("C", 0, m)
    for i in range(m):
        Y[i].add_markers([("disjoint", disjoint_id), ("symm", symm_id_y), ("symm_nonempty", 1)])
    return (eqdist(X + Y, X + Z) & markov(Y, X, Z)).exists(Z).forall(X + Y).add_meta("pf_note", ["copy lemma"])
def dblmarkov():
    """Double Markov property: If X-Y-Z and Y-X-Z are Markov chains, then there
    exists W that is a function of X, a function of Y, and (X,Y)-W-Z is Markov chain.
    Imre Csiszar and Janos Korner. Information theory: coding theorems for
    discrete memoryless systems. Cambridge University Press, 2011.

    The variables are named DX, DY, DZ, DW. DX and DY share "symm" and
    "nonsubset" marker groups.
    """
    # Fresh marker ids; the call order determines the ids.
    symm_id_x = iutil.get_count()
    nonsubset_id_x = iutil.get_count()
    X = rv("DX", latex = "D_X")
    Y = rv("DY", latex = "D_Y")
    Z = rv("DZ", latex = "D_Z")
    W = rv("DW", latex = "D_W")
    X.add_markers([("symm", symm_id_x), ("nonsubset", nonsubset_id_x), ("nonempty", 1)])
    Y.add_markers([("symm", symm_id_x), ("nonsubset", nonsubset_id_x), ("nonempty", 1)])
    Z.add_markers([("nonempty", 1)])
    # Hypotheses >> conclusion: the double Markov conditions imply the existence of W.
    return ((markov(X, Y, Z) & markov(Y, X, Z))
        >> ((H(W|X) == 0) & (H(W|Y) == 0) & markov(X+Y, W, Z)).exists(W)).forall(X+Y+Z).add_meta("pf_note", ["double Markov"])
def mmrv_thm(n = 2):
    """The non-Shannon inequality in the paper:
    Makarychev, K., Makarychev, Y., Romashchenko, A., & Vereshchagin, N. (2002). A new class of
    non-Shannon-type inequalities for entropies. Communications in Information and Systems, 2(2), 147-166.

    For all CX_0..CX_{n-1}, CU, CV, CZ:
        H(X) + n*I(U;V;Z) <= sum_i I(U;V|X_i) + sum_i H(X_i) + I(U,V;Z).
    """
    # Fresh marker ids; the call order determines the ids.
    disjoint_id = iutil.get_count()
    symm_id_x = iutil.get_count()
    symm_id_u = iutil.get_count()
    X = rv_seq("CX", 0, n)
    for i in range(n):
        X[i].add_markers([("disjoint", disjoint_id), ("symm", symm_id_x), ("symm_nonempty", 2)])
    U = Comp.rv("CU")
    V = Comp.rv("CV")
    Z = Comp.rv("CZ")
    U.add_markers([("nonempty", 1), ("symm", symm_id_u)])
    V.add_markers([("nonempty", 1), ("symm", symm_id_u)])
    Z.add_markers([("nonempty", 1)])
    # Build LHS - RHS and require it to be <= 0.
    expr = H(X) + n * I(U & V & Z)
    expr -= sum(I(U & V | Xi) for Xi in X)
    expr -= sum(H(Xi) for Xi in X)
    expr -= I(U+V & Z)
    return (expr <= 0).forall(X + U + V + Z)
def zydfz_thm(mode = ""):
    """The non-Shannon inequalities in the paper:
    Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities,"
    IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998.
    Randall Dougherty, Christopher Freiling, and Kenneth Zeger. "Six new non-Shannon
    information inequalities." 2006 IEEE International Symposium on Information Theory. IEEE, 2006.

    mode (case-insensitive): "zy" gives only the Zhang-Yeung inequality,
    "dfz" gives only the six DFZ inequalities, and anything else gives all
    seven. The result is universally quantified over A, B, C, D.
    """
    disjoint_id = iutil.get_count()
    A = Comp.rv("A")
    B = Comp.rv("B")
    C = Comp.rv("C")
    D = Comp.rv("D")
    A.add_markers([("disjoint", disjoint_id), ("nonempty", 1)])
    B.add_markers([("disjoint", disjoint_id), ("nonempty", 1)])
    C.add_markers([("disjoint", disjoint_id), ("nonempty", 1)])
    D.add_markers([("disjoint", disjoint_id), ("nonempty", 1)])
    r = Region.universe()
    if mode.lower() != "dfz":
        r &= (2*I(C&D) <= I(A&B) + I(A&C+D) + 3*I(C&D|A) + I(C&D|B) ).add_meta("pf_note", ["ZY ineq."]) # ZY
    if mode.lower() != "zy":
        r &= (2*I(A&B) <= 3*I(A&B|C) + 3*I(A&C|B) + 3*I(B&C|A) + 2*I(A&D) + 2*I(B&C|D) ).add_meta("pf_note", ["DFZ ineq. 1"]) # DFZ1
        r &= (2*I(A&B) <= 4*I(A&B|C) + I(A&C|B) + 2*I(B&C|A) + 3*I(A&B|D) + I(B&D|A) + 2*I(C&D)).add_meta("pf_note", ["DFZ ineq. 2"]) # DFZ2
        r &= (2*I(A&B) <= 3*I(A&B|C) + 2*I(A&C|B) + 4*I(B&C|A) + 2*I(A&C|D) + I(A&D|C) + 2*I(B&D) + I(C&D|A)).add_meta("pf_note", ["DFZ ineq. 3"]) # DFZ3
        r &= (2*I(A&B) <= 5*I(A&B|C) + 3*I(A&C|B) + I(B&C|A) + 2*I(A&D) + 2*I(B&C|D) ).add_meta("pf_note", ["DFZ ineq. 4"]) # DFZ4
        r &= (2*I(A&B) <= 4*I(A&B|C) + 4*I(A&C|B) + I(B&C|A) + 2*I(A&D) + 3*I(B&C|D) + I(C&D|B) ).add_meta("pf_note", ["DFZ ineq. 5"]) # DFZ5
        r &= (2*I(A&B) <= 3*I(A&B|C) + 2*I(A&C|B) + 2*I(B&C|A) + 2*I(A&B|D) + I(A&D|B) + I(B&D|A) + 2*I(C&D)).add_meta("pf_note", ["DFZ ineq. 6"]) # DFZ6
    return r.forall(A+B+C+D)
def dfz_thm(ncopy = 3):
    """The non-Shannon inequalities in the paper:
    Dougherty, Randall, Chris Freiling, and Kenneth Zeger. "Non-Shannon information
    inequalities in four random variables." arXiv preprint arXiv:1104.3602 (2011).

    Coefficient tables:
      - 7-tuples (a..g):
            2 I(A;B) <= a I(A;B|C) + b I(A;C|B) + c I(B;C|A) + d I(A;B|D)
                        + e I(A;D|B) + f I(B;D|A) + g I(C;D)
        The first tuple is always included; five more are added only when
        ncopy == 2.
      - 9-tuples (a..i), added when ncopy >= 3:
            a I(A;B) <= b I(A;B|C) + c I(A;C|B) + d I(B;C|A) + e I(A;B|D)
                        + f I(A;D|B) + g I(B;D|A) + h I(C;D) + i I(C;D|A)
    The result is universally quantified over A, B, C, D.
    """
    disjoint_id = iutil.get_count()
    A = Comp.rv("A")
    B = Comp.rv("B")
    C = Comp.rv("C")
    D = Comp.rv("D")
    A.add_markers([("disjoint", disjoint_id), ("nonempty", 1)])
    B.add_markers([("disjoint", disjoint_id), ("nonempty", 1)])
    C.add_markers([("disjoint", disjoint_id), ("nonempty", 1)])
    D.add_markers([("disjoint", disjoint_id), ("nonempty", 1)])
    r = Region.universe()
    cl = []
    cl += [(4, 2, 1, 3, 1, 0, 2)]
    # NOTE(review): the extra two-copy inequalities are only added when
    # ncopy == 2 exactly, not when ncopy >= 3 -- confirm this is intended.
    if ncopy == 2:
        cl += [
            (5, 3, 1, 2, 0, 0, 2),
            (4, 4, 1, 2, 1, 1, 2),
            (3, 3, 3, 2, 0, 0, 2),
            (3, 4, 2, 3, 1, 0, 2),
            (3, 2, 2, 2, 1, 1, 2)
        ]
    for a, b, c, d, e, f, g in cl:
        r &= (2*I(A&B) <= a*I(A&B|C) + b*I(A&C|B) + c*I(B&C|A) + d*I(A&B|D) + e*I(A&D|B) + f*I(B&D|A) + g*I(C&D)
            ).add_meta("pf_note", ["DFZ ineq. 2 copy"])
    if ncopy >= 3:
        # cl is rebound to the 9-tuple table; the 7-tuples were applied above.
        cl = [
            (3, 4, 4, 4, 3, 1, 1, 3, 0),
            (3, 9, 6, 1, 3, 0, 0, 3, 0),
            (3, 4, 6, 6, 3, 0, 0, 3, 0),
            (2, 3, 3, 1, 5, 2, 0, 2, 0),
            (3, 4, 3, 3, 3, 3, 3, 3, 0),
            (3, 6, 3, 1, 6, 3, 0, 3, 0),
            (4, 5, 8, 8, 4, 1, 1, 4, 0),
            (2, 4, 2, 1, 2, 0, 0, 2, 3),
            (4, 5, 5, 5, 4, 4, 4, 4, 0),
            (2, 3, 3, 2, 2, 0, 0, 2, 0),
            (3, 7, 5, 1, 3, 1, 1, 3, 0),
            (4, 6, 4, 3, 4, 2, 1, 4, 0),
            (2, 5, 2, 1, 2, 0, 0, 2, 0),
            (2, 4, 3, 1, 2, 0, 0, 2, 0),
            (2, 4, 1, 2, 2, 3, 0, 2, 0),
            (2, 4, 2, 1, 2, 4, 1, 2, 0)
        ]
        cl += [
            (3, 4, 9, 3, 6, 3, 0, 3, 0),
            (3, 7, 4, 1, 4, 1, 0, 3, 0),
            (3, 4, 6, 4, 4, 1, 0, 3, 0),
            (4, 5, 17, 6, 6, 7, 0, 4, 0),
            (4, 5, 17, 13, 6, 2, 0, 4, 0),
            (3, 4, 7, 5, 3, 1, 0, 3, 0),
            (6, 8, 9, 9, 6, 10, 1, 6, 0),
            (6, 13, 20, 2, 9, 3, 0, 6, 0),
            (4, 10, 15, 1, 4, 2, 2, 4, 0),
            (4, 6, 11, 3, 6, 2, 0, 4, 0),
            (3, 6, 6, 1, 5, 4, 0, 3, 0),
            (3, 6, 8, 1, 3, 2, 2, 3, 0),
            (4, 5, 6, 6, 4, 2, 2, 4, 0),
            (3, 8, 6, 1, 3, 1, 0, 3, 0),
            (4, 14, 10, 1, 6, 2, 0, 4, 0),
            (3, 4, 4, 3, 3, 4, 2, 3, 0),
            (4, 13, 9, 1, 7, 3, 0, 4, 0),
            (6, 8, 16, 7, 6, 3, 3, 6, 0)
        ]
        for a, b, c, d, e, f, g, h, i in cl:
            r &= (a*I(A&B) <= b*I(A&B|C) + c*I(A&C|B) + d*I(B&C|A) + e*I(A&B|D) + f*I(A&D|B) + g*I(B&D|A) + h*I(C&D) + i*I(C&D|A)
                ).add_meta("pf_note", ["DFZ ineq. 3 copy"])
    return r.forall(A+B+C+D)
def matus_thm(min_s, max_s):
    """The non-Shannon inequalities in the paper:
    F. Matus, "Infinitely many information inequalities", Proc. IEEE International
    Symposium on Information Theory, 2007

    For every integer s in [min_s, max_s] (inclusive):
        s I(A;B) <= s(s+3)/2 I(A;B|C) + s(s+1)/2 I(A;C|B) + I(B;C|A)
                    + s I(A;B|D) + s I(C;D)
    The result is universally quantified over A, B, C, D.
    """
    disjoint_id = iutil.get_count()
    A = Comp.rv("A")
    B = Comp.rv("B")
    C = Comp.rv("C")
    D = Comp.rv("D")
    A.add_markers([("disjoint", disjoint_id), ("nonempty", 1)])
    B.add_markers([("disjoint", disjoint_id), ("nonempty", 1)])
    C.add_markers([("disjoint", disjoint_id), ("nonempty", 1)])
    D.add_markers([("disjoint", disjoint_id), ("nonempty", 1)])
    r = Region.universe()
    for s in range(min_s, max_s + 1):
        # s(s+3) and s(s+1) are always even, so these true divisions produce
        # whole-number coefficients (as floats).
        r &= (s*I(A&B) <= (s*(s+3)/2) * I(A&B|C) + (s*(s+1)/2) * I(A&C|B) + I(B&C|A) + s*I(A&B|D) + s*I(C&D)
            ).add_meta("pf_note", ["Matus ineq. s = " + str(s)])
    return r.forall(A+B+C+D)
def ainfdiv(n = 2, coeff = None, cadd = None):
    """The approximate infinite divisibility of information.
    C. T. Li, "Infinite Divisibility of Information,"
    arXiv preprint arXiv:2008.06092 (2020).

    For all X there exist i.i.d. Z_0..Z_{n-1} determining X with
    H(Z_0) <= H(X) * coeff / n + cadd.
    Defaults: coeff = 1 / (1 - (1 - 1/n)^n) and cadd = 2.43.
    """
    X = Comp.rv("X")
    Z = rv_seq("Z", 0, n)
    if coeff is None:
        scale = 1.0 / (1.0 - (1.0 - 1.0 / n) ** n) / n
    else:
        scale = coeff / n
    offset = 2.43 if cadd is None else cadd
    r = iidseq(*Z) & (H(X | Z) == 0)
    r &= H(Z[0]) <= H(X) * scale + offset
    return r.exists(Z).forall(X)
def exists_xor():
    """For any X, there exists Z0 uniformly distributed and Z1 = X XOR Z0.

    Encoded with entropies only: X is independent of Z0 and of Z1
    individually, and X is a function of (Z0, Z1).
    """
    src = Comp.rv("X")
    shares = rv_seq("Z", 0, 2)
    cond = indep(src, shares[0]) & indep(src, shares[1]) & (H(src | shares) == 0)
    return cond.exists(shares).forall(src).add_meta("pf_note", ["exists XOR"])
def exists_linear(n, p = 2, maxsize = None):
    """There exists linear combinations of n i.i.d. GF(p) elements. p must be a prime.

    Coefficient vectors are enumerated over supports of size 1..maxsize
    (default n), with the first nonzero coefficient fixed to 1. Each vector
    gets a random variable (named e.g. "V0(+)2V1") with entropy log(p),
    scaled by ent_coeff. For every choice of n vectors whose coefficient
    matrix is invertible mod p, those n variables are independent and
    determine all the others.
    """
    if maxsize is None:
        maxsize = n
    vs = []
    # Presumably igen.subset yields the bitwise sum (mask) of each subset of
    # the bits 1 << i, so mask marks the support of the coefficient vector.
    for mask in igen.subset([1 << i for i in range(n)], minsize = 1, maxsize = maxsize):
        gens = []
        first = True
        for i in range(n):
            if not (mask & (1 << i)):
                gens.append([0])
                continue
            if first:
                # Normalize: the first nonzero coefficient is 1.
                gens.append([1])
                first = False
            else:
                gens.append(list(range(1, p)))
        for t in itertools.product(*gens):
            vs.append(t)
    loge = PsiOpts.settings["ent_coeff"]
    logp = numpy.log(p) * loge
    r = Region.universe()
    V = Comp.empty()
    for v in vs:
        # Build plain-text and LaTeX names for the linear combination.
        cname = ""
        cname_latex = ""
        for i, a in enumerate(v):
            if a == 0:
                continue
            if cname != "":
                cname += "(+)"
            if cname_latex != "":
                cname_latex += r"\oplus "
                # cname_latex += "+_{" + str(p) + "}"
            if a != 1:
                cname += str(a)
                cname_latex += str(a)
            cname += "V" + str(i)
            cname_latex += "V_{" + str(i) + "}"
        crv = rv(cname, latex = cname_latex, split = False)
        r &= Expr.H(crv) == logp
        V += crv
    for basis in itertools.combinations(range(len(vs)), n):
        mat = numpy.array([vs[i] for i in basis], dtype = numpy.int64)
        # Skip sets that are not a basis over GF(p). The determinant is
        # computed in floating point and rounded; NOTE(review): this may be
        # inexact for large n or p.
        if int(round(numpy.linalg.det(mat))) % p == 0:
            continue
        Vb = sum((V[i] for i in basis), Comp.empty())
        r &= indep(*Vb)
        r &= Expr.Hc(V - Vb, Vb) == 0
    # NOTE(review): the pf_note reads "exists XOR" even for p != 2 -- possibly
    # copied from exists_xor; confirm the intended label.
    return r.exists(V).add_meta("pf_note", ["exists XOR"])
def existence(f, numarg = 2, nonempty = False):
    """A region for the existence of the random variable f(X, Y).
    e.g. existence(meet), existence(mss)

    Applies f to placeholder variables T[0..numarg-1], takes the region
    registered on the result, and returns "for all T, there exists X
    satisfying that region". If f's result carries the "symm_args" or
    "nonsubset_args" marker, matching "symm"/"nonsubset" markers are added
    to T. If nonempty is True, X gets a "nonempty" marker.
    """
    T = rv_seq("T", 0, numarg)
    X = f(*T)
    # This first r is overwritten below; only X's markers are inspected here.
    r = X.varlist[0].reg
    if X.get_marker_key("symm_args") is not None: # r.issymmetric(T):
        symm_id_x = iutil.get_count()
        for i in range(numarg):
            T[i].add_markers([("symm", symm_id_x)])
    if X.get_marker_key("nonsubset_args") is not None:
        nonsubset_id_x = iutil.get_count()
        for i in range(numarg):
            T[i].add_markers([("nonsubset", nonsubset_id_x)])
    # Re-apply f, presumably so the result is built from the newly marked T.
    X = f(*T)
    r = X.varlist[0].reg
    X = X.copy_noreg()
    if nonempty:
        X.add_markers([("nonempty", 1)])
    # NOTE(review): substituting X by itself presumably swaps the registered
    # occurrences in r for the unregistered copy -- confirm.
    r.substitute(X, X)
    return r.exists(X).forall(T)
def rv_bit():
    """A random variable with entropy 1.

    Returns the random variable "BIT" registered with the region H(BIT) == 1.
    """
    bit = Comp.rv("BIT")
    unit_entropy = H(bit) == 1
    return Comp.rv_reg(bit, unit_entropy)
def exists_bit(n = 1):
    """There exists a random variable with entropy 1.

    More generally, there exist n mutually independent variables, each with
    entropy 1.
    """
    bits = rv_seq("BIT", 0, n)
    unit = alland([H(b) == 1 for b in bits])
    return (unit & indep(*bits)).exists(bits)
@fcn_list_to_list
def emin(*args):
    """Return the minimum of the expressions.

    Defined as the maximum real R subject to R <= x for every argument x.
    """
    label = iutil.fcn_name_maker("min", args, pname = "emin", lname = "\\min", fcn_suffix = False)
    R = real(str(real(label)))
    r = universe()
    for bound in args:
        r &= R <= bound
    return r.maximum(R, None, allow_reuse = True)
@fcn_list_to_list
def emax(*args):
    """Return the maximum of the expressions.

    Defined as the minimum real R subject to R >= x for every argument x.
    """
    label = iutil.fcn_name_maker("max", args, pname = "emax", lname = "\\max", fcn_suffix = False)
    R = real(str(real(label)))
    r = universe()
    for bound in args:
        r &= R >= bound
    return r.minimum(R, None, allow_reuse = True)
@fcn_list_to_list
def eabs(x):
    """Absolute value of expression.

    Defined as the minimum real R subject to R >= x and R >= -x.
    """
    R = real(iutil.fcn_name_maker("abs", x, fcn_suffix = False))
    both_sides = (R >= x) & (R >= -x)
    return both_sides.minimum(R, None, allow_reuse = True)
@fcn_list_to_list
def meet(*args):
    """Gacs-Korner common part.
    Peter Gacs and Janos Korner. Common information is far less than mutual information.
    Problems of Control and Information Theory, 2(2):149-162, 1973.

    Returns a random variable U registered (with reg_det = True) with the
    region: U is a function of every argument, and every V that is a
    function of every argument is a function of U.
    """
    U = Comp.rv(iutil.fcn_name_maker("meet", args)).avoid(*args)
    V = Comp.rv("V").avoid(*args)
    U.add_markers([("mustuse", 1), ("symm_args", 1), ("nonsubset_args", 1)])
    V.add_markers([("nonempty", 1)])
    r = Region.universe()
    r2 = Region.universe()
    for a in args:
        # r: U is a function of a. r2: V is a function of a.
        r &= (Expr.Hc(U, a) == 0).add_meta("pf_note", ["meet"])
        r2 &= (Expr.Hc(V, a) == 0).add_meta("pf_note", ["meet"])
    # Maximality: presumably "for all V, r2 implies H(V | U) == 0".
    r = r & (Expr.Hc(V, U) == 0).implicated(r2, skip_simplify = True).forall(V).add_meta("pf_note", ["meet maximal"])
    ret = Comp.rv_reg(U, r, reg_det = True)
    #ret.add_markers([("symm_args", 1), ("nonsubset_args", 1)])
    return ret
@fcn_list_to_list
def mss(x, y):
    """Minimal sufficient statistic of x about y.

    Returns a random variable U registered (with reg_det = True) with the
    region: U is a function of x with I(x; y | U) == 0, and U is a function
    of every V that also satisfies both conditions.
    """
    U = Comp.rv(iutil.fcn_name_maker("mss", [x, y])).avoid(x, y)
    V = Comp.rv("V").avoid(x, y)
    U.add_markers([("mustuse", 1), ("nonsubset_args", 1)])
    # Sufficiency of U (r) and of a competing statistic V (r2).
    r = (Expr.Hc(U, x) == 0) & (Expr.Ic(x, y, U) == 0)
    r2 = (Expr.Hc(V, x) == 0) & (Expr.Ic(x, y, V) == 0)
    # Minimality: presumably "for all V, r2 implies H(U | V) == 0".
    r = r & (Expr.Hc(U, V) == 0).implicated(r2, skip_simplify = True).forall(V)
    ret = Comp.rv_reg(U, r, reg_det = True)
    #ret.add_markers([("nonsubset_args", 1)])
    return ret
@fcn_list_to_list
def sfrl_rv(x, y, gap = None):
    """Strong functional representation lemma.
    Li, C. T., & El Gamal, A. (2018). Strong functional representation lemma and
    applications to coding theorems. IEEE Trans. Info. Theory, 64(11), 6967-6978.

    Returns a random variable U (registered with reg_det = False) that is
    independent of x and such that y is a function of (x, U). If gap is
    given, I(x; U | y) <= gap is also required.
    """
    U = Comp.rv(iutil.fcn_name_maker("sfrl", [x, y], pname = "sfrl_rv")).avoid(x, y)
    r = (Expr.Hc(y, x + U) == 0) & (Expr.I(x, U) == 0)
    if gap is not None:
        bound = gap if isinstance(gap, Expr) else Expr.const(gap)
        r &= Expr.Ic(x, U, y) <= bound
    return Comp.rv_reg(U, r, reg_det = False)
def esfrl_rv(x, y, gap = None):
    """Strong functional representation lemma, extended form.
    Li, C. T., & El Gamal, A. (2018). Strong functional representation lemma and
    applications to coding theorems. IEEE Trans. Info. Theory, 64(11), 6967-6978.

    Returns a pair (U, K) of random variables, both registered (with
    reg_det = False) with the same region:
      - U is independent of x;
      - K is a function of (x, U);
      - y is a function of (U, K);
      - if gap is given, H(K) <= I(x; y) + gap (a non-Expr gap is converted
        with Expr.const).
    """
    U = Comp.rv(iutil.fcn_name_maker("esfrl", [x, y], pname = "esfrl_rv")).avoid(x, y)
    K = Comp.rv(iutil.fcn_name_maker("esfrl_K", [x, y], pname = "esfrl_rv_K")).avoid(x, y)
    # Use the parameter y; the previous code referenced an undefined name Y.
    r = (Expr.Hc(K, x + U) == 0) & (Expr.Hc(y, U + K) == 0) & (Expr.I(x, U) == 0)
    if gap is not None:
        if not isinstance(gap, Expr):
            gap = Expr.const(gap)
        r &= Expr.H(K) <= Expr.I(x, y) + gap
    return (Comp.rv_reg(U, r, reg_det = False), Comp.rv_reg(K, r, reg_det = False))
def copylem_rv(x, y):
    """Copy lemma: for any X, Y, there exists Z such that (X, Y) has the same
    distribution as (X, Z), and Y-X-Z forms a Markov chain.
    Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities,"
    IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998.
    Randall Dougherty, Chris Freiling, and Kenneth Zeger. "Non-Shannon information
    inequalities in four random variables." arXiv preprint arXiv:1104.3602 (2011).

    Returns the copy as a random variable registered (reg_det = False) with
    the eqdist and Markov constraints.
    """
    U = Comp.rv(iutil.fcn_name_maker("copy", [x, y], pname = "copylem_rv")).avoid(x, y)
    base = list(x)
    same_dist = eqdist(base + [y], base + [U])
    r = same_dist & markov(y, sum(x), U)
    return Comp.rv_reg(U, r, reg_det = False)
@fcn_list_to_list
def total_corr(*args):
    """Total correlation.
    Watanabe S (1960). Information theoretical analysis of multivariate correlation,
    IBM Journal of Research and Development 4, 66-82.
    e.g. total_corr(X & Y & Z | W)

    Computed as sum_i H(X_i | W) - H(X_1, ..., X_k | W).
    """
    term = Term.from_symbols(args, prefer_multi=True)
    if isinstance(term, Comp):
        return Expr.H(term)
    marginals = sum([Expr.Hc(part, term.z) for part in term.x])
    return marginals - Expr.Hc(sum(term.x), term.z)
@fcn_list_to_list
def dual_total_corr(*args):
    """Dual total correlation.
    Han T. S. (1978). Nonnegative entropy measures of multivariate symmetric
    correlations, Information and Control 36, 133-156.
    e.g. dual_total_corr(X & Y & Z | W)

    Computed as H(X_1..X_k | W) - sum_i H(X_i | X_{-i}, W).
    """
    term = Term.from_symbols(args, prefer_multi=True)
    if isinstance(term, Comp):
        return Expr.H(term)
    parts = term.x
    r = Expr.Hc(sum(parts), term.z)
    for i, part in enumerate(parts):
        rest = sum([q for j, q in enumerate(parts) if j != i])
        r -= Expr.Hc(part, rest + term.z)
    return r
def prob_markov_constraints(Xs, U, p):
    """Build a concrete model in which U is a latent variable between Xs[0] and the rest.

    The model P fixes P[Xs[0]] to ``p.marginal(0)``. The entries P[U | Xs[0]]
    and P[X | U] for every X in Xs[1:] are declared with the spec
    "var,rand" (presumably randomly initialized optimization variables).

    Returns:
        (P, var_list, cons), where var_list holds the declared entries and
        cons requires the model's pmf of all Xs given Xs[0] to equal p / p0.
    """
    P = ConcModel()
    p0 = p.marginal(0)
    P[Xs[0]] = p0
    P[U | Xs[0]] = "var,rand"
    var_list = [P[U | Xs[0]]]
    joint = Comp.empty()
    for i, X in enumerate(Xs):
        if i:
            P[X | U] = "var,rand"
            var_list.append(P[X | U])
        joint += X
    cons = (joint | Xs[0]).pmf() == p / p0
    return (P, var_list, cons)
@fcn_list_to_list
def prob_symbol(*args, cname=None):
    """The probability symbol. Only used for display.

    Returns a real variable named after the term, using cname as the
    displayed function name (default "P").
    """
    term = Term.from_symbols(args, prefer_multi=False)
    label = "P" if cname is None else cname
    return real(iutil.fcn_name_maker(label, term, pname = "P", cropi = True))
@fcn_list_to_list
def gacs_korner(*args, mi = None):
    """Gacs-Korner common information.
    Peter Gacs and Janos Korner. Common information is far less than mutual information.
    Problems of Control and Information Theory, 2(2):149-162, 1973.
    e.g. gacs_korner(X & Y & Z | W)

    mi selects the characterization (X_i are the parts of the term, W the
    conditioning):
        False -- max H(U | W) over U that is a function of every (X_i, W).
        True  -- max I(U; X_1..X_k | W) subject to
                 I(U; X_{-i} | X_i, W) == 0 for every i.
        None  -- (default) the mutual-information form, with the functional
                 form passed as ``reg_outer`` to ``maximum``.
    """
    x = Term.from_symbols(args, prefer_multi=True)
    U = Comp.rv("U").avoid(x)
    R = real(iutil.fcn_name_maker("K", x, pname = "gacs_korner" + ("_mi" if mi is True else ""), cropi = True))
    r = None
    ro = None
    if mi is not True:
        # Functional characterization.
        r = universe()
        for a in x.x:
            r &= Expr.Hc(U, a+x.z) == 0
        r &= R <= Expr.Hc(U, x.z)
    if mi is not False:
        # Mutual-information characterization.
        ro = universe()
        sumx = sum(x.x, Comp.empty())
        for a in x.x:
            ro &= Expr.Ic(U, sumx - a, a + x.z) == 0
        ro &= R <= Expr.Ic(U, sumx, x.z)
    if mi is False:
        return r.add_meta("pf_note", ["def. Gacs-Korner"]).exists(U).maximum(R, None, allow_reuse = True)
    elif mi is True:
        return ro.add_meta("pf_note", ["def. Gacs-Korner"]).exists(U).maximum(R, None, allow_reuse = True)
    else:
        return ro.add_meta("pf_note", ["def. Gacs-Korner"]).exists(U).maximum(R, None,
            allow_reuse = True, reg_outer = r.add_meta("pf_note", ["def. Gacs-Korner"]).exists(U))
@fcn_list_to_list
def gacs_korner_mi(*args):
    """Gacs-Korner common information, alternative characterization via mutual information.
    Peter Gacs and Janos Korner. Common information is far less than mutual information.
    Problems of Control and Information Theory, 2(2):149-162, 1973.
    e.g. gacs_korner_mi(X & Y & Z | W)

    Maximum of I(U; X_1..X_k | W) over U with I(U; X_{-i} | X_i, W) == 0
    for every i.
    """
    term = Term.from_symbols(args, prefer_multi=True)
    aux = Comp.rv("U").avoid(term)
    R = real(iutil.fcn_name_maker("\\tilde{K}", term, pname = "gacs_korner_mi", cropi = True))
    r = universe()
    joint = sum(term.x, Comp.empty())
    for part in term.x:
        r &= Expr.Ic(aux, joint - part, part + term.z) == 0
    r &= R <= Expr.Ic(aux, joint, term.z)
    return r.exists(aux).maximum(R, None, allow_reuse = True)
@fcn_list_to_list
def wyner_ci(*args):
    """Wyner's common information.
    A. D. Wyner. The common information of two dependent random variables.
    IEEE Trans. Info. Theory, 21(2):163-179, 1975.
    e.g. wyner_ci(X & Y & Z | W)

    If the argument is a concrete distribution (ConcDist), the value is
    computed numerically. Otherwise a symbolic expression is returned: the
    minimum of I(U; X_1..X_k | W) over U such that the X_i are mutually
    independent given (U, W).
    """
    x = Term.from_symbols(args, prefer_multi=True)
    def fcncall(xdist):
        # Numerical evaluation for a concrete joint distribution.
        xdist = xdist.flattened_sublen()
        Xs = Comp.array("X", len(xdist.p.shape))
        Xs = sum((X.set_card(s) for X, s in zip(Xs, xdist.p.shape)), Comp.empty())
        # Cardinality bound for U: product of the alphabet sizes, plus 1.
        card_bd = iutil.product(xdist.p.shape) + 1
        U = Comp.rv("U").set_card(card_bd)
        P, vars, cons = prob_markov_constraints(Xs, U, xdist)
        return P.minimize(I(Xs & U), vars, cons)
    if isinstance(x, ConcDist):
        return fcncall(x)
    U = Comp.rv("U").avoid(x)
    R = real(iutil.fcn_name_maker("J", x, pname = "wyner_ci", cropi = True))
    r = indep(*(x.x)).conditioned(U + x.z)
    r &= R >= Expr.Ic(U, sum(x.x), x.z)
    r = r.add_meta("pf_note", ["def. Wyner CI"]).exists(U).minimum(R, None, allow_reuse = True)
    # Attach the numerical evaluator and its argument to the resulting term,
    # presumably so it can be evaluated later on a concrete distribution.
    r.terms[0][0].fcncall = fcncall
    r.terms[0][0].fcnargs = [x]
    return r
@fcn_list_to_list
def exact_ci(*args):
    """Common entropy (one-shot exact common information).
    G. R. Kumar, C. T. Li, and A. El Gamal. Exact common information. In Information
    Theory (ISIT), 2014 IEEE International Symposium on, 161-165. IEEE, 2014.
    e.g. exact_ci(X & Y & Z | W)

    If the argument is a concrete distribution (ConcDist), the value is
    computed numerically. Otherwise a symbolic expression is returned: the
    minimum of H(U | W) over U such that the X_i are mutually independent
    given (U, W).
    """
    x = Term.from_symbols(args, prefer_multi=True)
    def fcncall(xdist):
        # Numerical evaluation for a concrete joint distribution.
        xdist = xdist.flattened_sublen()
        Xs = Comp.array("X", len(xdist.p.shape))
        Xs = sum((X.set_card(s) for X, s in zip(Xs, xdist.p.shape)), Comp.empty())
        tprod = iutil.product(xdist.p.shape)
        # Cardinality bound for U.
        card_bd = min(tprod, 2 ** (tprod // max(xdist.p.shape)) - 1)
        U = Comp.rv("U").set_card(card_bd)
        P, vars, cons = prob_markov_constraints(Xs, U, xdist)
        return P.minimize(H(U), vars, cons)
    if isinstance(x, ConcDist):
        return fcncall(x)
    U = Comp.rv("U").avoid(x)
    R = real(iutil.fcn_name_maker("G", x, pname = "exact_ci", cropi = True))
    r = indep(*(x.x)).conditioned(U + x.z)
    r &= R >= Expr.Hc(U, x.z)
    r = r.add_meta("pf_note", ["def. common entropy"]).exists(U).minimum(R, None, allow_reuse = True)
    # Attach the numerical evaluator and its argument to the resulting term,
    # presumably so it can be evaluated later on a concrete distribution.
    r.terms[0][0].fcncall = fcncall
    r.terms[0][0].fcnargs = [x]
    return r
@fcn_list_to_list
def H_nec(*args):
    """Necessary conditional entropy.
    Cuff, P. W., Permuter, H. H., & Cover, T. M. (2010). Coordination capacity.
    IEEE Transactions on Information Theory, 56(9), 4181-4206.
    e.g. H_nec(X + Y | W)

    Minimum of H(U | W) over U that is a function of X with W-U-X Markov.
    """
    term = Term.from_symbols(args)
    aux = Comp.rv("U").avoid(term)
    R = real(iutil.fcn_name_maker("Hnec", term, pname = "H_nec", lname = "H^\\dagger", cropi = True))
    target = term.x[0]
    r = markov(term.z, aux, target) & (Expr.Hc(aux, target) == 0)
    r &= R >= Expr.Hc(aux, term.z)
    return r.add_meta("pf_note", ["def. nec. cond. ent."]).exists(aux).minimum(R, None, allow_reuse = True)
@fcn_list_to_list
def excess_fi(x, y):
    """Excess functional information.
    Li, C. T., & El Gamal, A. (2018). Strong functional representation lemma and
    applications to coding theorems. IEEE Trans. Info. Theory, 64(11), 6967-6978.
    e.g. excess_fi(X, Y)

    Minimum of H(y | U) - I(x; y) over U independent of x.
    """
    aux = Comp.rv("U").avoid(x, y)
    R = real(iutil.fcn_name_maker("excess_fi", [x, y], pname = "excess_fi", lname = "\\Psi"))
    r = indep(aux, x)
    r &= R >= Expr.Hc(y, aux) - Expr.I(x, y)
    return r.add_meta("pf_note", ["def. excess fcn info"]).exists(aux).minimum(R, None, allow_reuse = True)
@fcn_list_to_list
def korner_graph_ent(x, y):
    """Korner graph entropy.
    J. Korner, "Coding of an information source having ambiguous alphabet and the
    entropy of graphs," in 6th Prague conference on information theory, 1973, pp. 411-425.
    C. T. Li and A. El Gamal, "Extended Gray-Wyner system with complementary
    causal side information," IEEE Transactions on Information Theory 64.8 (2017): 5862-5878.
    e.g. korner_graph_ent(X, Y)

    Minimum of I(x; U) over U with U-x-y Markov and x a function of (y, U).
    """
    aux = Comp.rv("U").avoid(x, y)
    R = real(iutil.fcn_name_maker("korner_graph_ent", [x, y], lname = "H_K"))
    r = markov(aux, x, y) & (Expr.Hc(x, y + aux) == 0)
    r &= R >= Expr.I(x, aux)
    return r.exists(aux).minimum(R, None, allow_reuse = True)
@fcn_list_to_list
def perfect_privacy(x, y):
    """Perfect privacy rate.
    A. Makhdoumi, S. Salamatian, N. Fawaz, and M. Medard, "From the information bottleneck
    to the privacy funnel," in Information Theory Workshop (ITW), 2014 IEEE, Nov 2014, pp. 501-505.
    S. Asoodeh, F. Alajaji, and T. Linder, "Notes on information-theoretic privacy," in Communication,
    Control, and Computing (Allerton), 2014 52nd Annual Allerton Conference on, Sept 2014, pp. 1272-1278.
    e.g. perfect_privacy(X, Y)

    Maximum of I(y; U) over U with x-y-U Markov and I(x; U) == 0.
    """
    aux = Comp.rv("U").avoid(x, y)
    R = real(iutil.fcn_name_maker("perfect_privacy", [x, y], lname = "g_0"))
    r = markov(x, y, aux) & (Expr.I(x, aux) == 0)
    r &= R <= Expr.I(y, aux)
    return r.exists(aux).maximum(R, None, allow_reuse = True)
@fcn_list_to_list
def max_interaction_info(x, y):
    """Maximal interaction information.
    C. T. Li and A. El Gamal, "Extended Gray-Wyner system with complementary
    causal side information," IEEE Transactions on Information Theory 64.8 (2017): 5862-5878.
    e.g. max_interaction_info(X, Y)

    Maximum of I(x; y | U) - I(x; y) over unconstrained U.
    """
    aux = Comp.rv("U").avoid(x, y)
    R = real(iutil.fcn_name_maker("max_interaction_info", [x, y], lname = "G_{NNI}"))
    r = Region.universe()
    r &= R <= Expr.Ic(x, y, aux) - Expr.I(x, y)
    return r.exists(aux).maximum(R, None, allow_reuse = True)
@fcn_list_to_list
def asymm_interaction_info(x, y):
    """Asymmetric private interaction information G_{PNI}.

    Computed as the maximum of I(X;Y|U) - I(X;Y) over auxiliaries U that are
    independent of X.

    C. T. Li and A. El Gamal, "Extended Gray-Wyner system with complementary
    causal side information," IEEE Transactions on Information Theory 64.8 (2017): 5862-5878.
    e.g. asymm_interaction_info(X, Y)
    """
    # Create the auxiliary the same way as the sibling functions (Comp.rv).
    U = Comp.rv("U").avoid(x, y)
    R = real(iutil.fcn_name_maker("asymm_interaction_info", [x, y], lname = "G_{PNI}"))
    r = indep(x, U)
    r &= R <= Expr.Ic(x, y, U) - Expr.I(x, y)
    return r.exists(U).maximum(R, None, allow_reuse = True)
@fcn_list_to_list
def symm_interaction_info(x, y):
    """Symmetric private interaction information G_{PPI}.

    Computed as the maximum of I(X;Y|U) - I(X;Y) over auxiliaries U that are
    independent of X and independent of Y (individually).

    C. T. Li and A. El Gamal, "Extended Gray-Wyner system with complementary
    causal side information," IEEE Transactions on Information Theory 64.8 (2017): 5862-5878.
    e.g. symm_interaction_info(X, Y)
    """
    U = Comp.rv("U").avoid(x, y)
    R = real(iutil.fcn_name_maker("symm_interaction_info", [x, y], lname = "G_{PPI}"))
    r = indep(x, U) & indep(y, U)
    r &= R <= Expr.Ic(x, y, U) - Expr.I(x, y)
    return r.exists(U).maximum(R, None, allow_reuse = True)
@fcn_list_to_list
def minent_coupling(x, y):
    """Minimum entropy coupling of the distributions p_{Y|X=x}.

    Computed as the minimum of H(U) over auxiliaries U that are independent
    of X and such that Y is determined by (X, U).

    M. Vidyasagar, "A metric between probability distributions on finite sets of
    different cardinalities and applications to order reduction," IEEE Transactions
    on Automatic Control, vol. 57, no. 10, pp. 2464-2477, 2012.
    A. Painsky, S. Rosset, and M. Feder, "Memoryless representation of Markov processes,"
    in 2013 IEEE International Symposium on Information Theory. IEEE, 2013, pp. 2294-298.
    M. Kovacevic, I. Stanojevic, and V. Senk, "On the entropy of couplings,"
    Information and Computation, vol. 242, pp. 369-382, 2015.
    M. Kocaoglu, A. G. Dimakis, S. Vishwanath, and B. Hassibi, "Entropic causal inference,"
    in Thirty-First AAAI Conference on Artificial Intelligence, 2017.
    F. Cicalese, L. Gargano, and U. Vaccaro, "Minimum-entropy couplings and their
    applications," IEEE Transactions on Information Theory, vol. 65, no. 6, pp. 3436-3451, 2019.
    Cheuk Ting Li, "Efficient Approximate Minimum Entropy Coupling of Multiple
    Probability Distributions," https://arxiv.org/abs/2006.07955 , 2020.
    e.g. minent_coupling(X, Y)
    """
    aux = Comp.rv("U").avoid(x, y)
    val = real(iutil.fcn_name_maker("MEC", [x, y], pname = "minent_coupling", lname = "H_{couple}"))
    region = indep(aux, x)
    region &= Expr.Hc(y, x + aux) == 0
    region &= val >= Expr.H(aux)
    return region.exists(aux).minimum(val, None, allow_reuse = True)
@fcn_list_to_list
def mutual_dep(x):
    """Mutual dependence C_{MD} of the random variables in the term x.

    For two or fewer variables this is just the (conditional) mutual
    information I(x). Otherwise it is the minimum, over all partitions P of
    the variables into at least two cells, of
        (sum over cells C of H(C | z) - H(all | z)) / (|P| - 1),
    where z is the conditioning part x.z.

    Csiszar, Imre, and Prakash Narayan. "Secrecy capacities for multiple terminals."
    IEEE Transactions on Information Theory 50, no. 12 (2004): 3047-3061.
    """
    nvar = len(x.x)
    if nvar <= 2:
        return I(x)
    val = real(iutil.fcn_name_maker("MD", x, pname = "mutual_dep", lname = "C_{MD}", cropi = True))
    cond = x.z
    h_all = Expr.Hc(sum(x.x), cond)
    region = universe()
    for part in iutil.enum_partition(nvar):
        ncell = len(part)
        # The trivial one-cell partition is excluded (it would divide by zero).
        if ncell < 2:
            continue
        total = Expr.zero()
        for cell in part:
            # Each cell is a bitmask over the variables in x.x.
            members = Comp.empty()
            for i, comp in enumerate(x.x):
                if cell & (1 << i):
                    members += comp
            total += Expr.Hc(members, cond)
        region &= val <= (total - h_all) / (ncell - 1)
    return region.maximum(val, None, allow_reuse = True)
@fcn_list_to_list
def intrinsic_mi(x):
    """Intrinsic mutual information.

    Computed as the minimum, over auxiliaries U such that (variables of x) - x.z - U
    is a Markov chain, of the mutual dependence of the variables of x conditioned
    on U (for two variables, I(X;Y|U)).

    U. Maurer and S. Wolf. "Unconditionally secure key agreement and the intrinsic
    conditional information." IEEE Transactions on Information Theory 45.2 (1999): 499-514.
    e.g. intrinsic_mi(X & Y | Z)
    """
    aux = Comp.rv("U").avoid(x)
    val = real(iutil.fcn_name_maker("IMI", x, pname = "intrinsic_mi", lname = "I_{intrinsic}", cropi = True))
    region = markov(sum(x.x), x.z, aux)
    region = region & (val >= mutual_dep(Term(x.x, aux)))
    return region.exists(aux).minimum(val, None, allow_reuse = True)
def mi_rect(xs, ys, z = None, sgn = 1):
    """Max (sgn > 0) or min (otherwise) of penalized mutual informations.

    Each entry of xs / ys is either a bare value or a tuple. A bare value is
    wrapped as a 1-tuple and is always included in the joint variable with no
    penalty. For a 2-tuple (comp, penalty), both choices are enumerated:
    exclude comp (no penalty), or include comp and subtract penalty.
    NOTE(review): tuples appear to be expected to have length 1 or 2; longer
    tuples would enumerate extra indices p >= 2 -- confirm intended.

    For every combination, the candidate is Ic(cx; cy | z) - (penalties of the
    included entries). Candidates are combined with emax (sgn > 0) or emin.

    Parameters
    ----------
    xs, ys : a single entry or a list of entries (see above).
    z : conditioning random variable; Comp.empty() when None.
    sgn : > 0 for the maximum, otherwise the minimum.

    Returns
    -------
    The single candidate if only one exists, otherwise emax/emin of all candidates.
    """
    if z is None:
        z = Comp.empty()
    if not isinstance(xs, list):
        xs = [xs]
    if not isinstance(ys, list):
        ys = [ys]
    # Normalize every entry to a tuple: (comp,) or (comp, penalty).
    xs = [x if isinstance(x, tuple) else (x,) for x in xs]
    ys = [y if isinstance(y, tuple) else (y,) for y in ys]
    exprs = []
    # p == 0 means "exclude" (except for 1-tuples, which are always included).
    for px in itertools.product(*[range(len(x)) for x in xs]):
        for py in itertools.product(*[range(len(y)) for y in ys]):
            cx = sum((x[0] for x, p in zip(xs, px) if p or len(x) == 1), Comp.empty())
            cy = sum((y[0] for y, p in zip(ys, py) if p or len(y) == 1), Comp.empty())
            # Total penalty of the included tuple entries.
            cxm = sum((x[1] for x, p in zip(xs, px) if p), Expr.zero())
            cym = sum((y[1] for y, p in zip(ys, py) if p), Expr.zero())
            if sgn > 0:
                if cx.isempty() or cy.isempty():
                    # With an empty side only the penalty-free candidate (0) is
                    # kept; penalized empty-side candidates are skipped.
                    # NOTE(review): presumably because they are dominated by 0
                    # when penalties are nonnegative -- confirm.
                    if cxm.iszero() and cym.iszero():
                        exprs.append(Expr.zero())
                else:
                    exprs.append(Expr.Ic(cx, cy, z) - cxm - cym)
            else:
                if cx.isempty() or cy.isempty():
                    # Empty side: the mutual information term is dropped.
                    exprs.append(-cxm - cym)
                else:
                    exprs.append(Expr.Ic(cx, cy, z) - cxm - cym)
    if len(exprs) == 1:
        return exprs[0]
    if sgn > 0:
        return emax(*exprs)
    else:
        return emin(*exprs)
def mi_rect_max(xs, ys, z = None):
    """Maximum form of mi_rect: calls mi_rect with sgn = 1 (candidates combined by emax)."""
    return mi_rect(xs, ys, z = z, sgn = 1)
def mi_rect_min(xs, ys, z = None):
    """Minimum form of mi_rect: calls mi_rect with sgn = -1 (candidates combined by emin)."""
    return mi_rect(xs, ys, z = z, sgn = -1)
def directed_info(x, y, z = None):
    """Directed information from x to y, optionally causally conditioned on z.

    Sums I(x.past_ns() & y | y.past()) over the array entries, adding
    z.past_ns() to the conditioning when z is given.

    Massey, James. "Causality, feedback and directed information." Proc. Int.
    Symp. Inf. Theory Applic.(ISITA-90). 1990.
    Parameters can be either Comp or CompArray.
    """
    x = CompArray.arg_convert(x)
    y = CompArray.arg_convert(y)
    if z is None:
        cond = y.past()
    else:
        cond = y.past() + CompArray.arg_convert(z).past_ns()
    return sum(I(x.past_ns() & y | cond))
def comp_vector(*args):
    """CompArray of the joint variables of all nonempty subsets of args, in igen.subset order."""
    subsets = igen.subset(args, minsize = 1)
    return CompArray.make(subsets)
def comp_vector_lex(*args):
    """CompArray of the joint variables of all nonempty subsets of args.

    Subsets are ordered by their bitmask 1, 2, ..., 2^n - 1, where bit i selects args[i].
    """
    count = len(args)

    def joint(mask):
        # Joint random variable of the arguments selected by the bitmask.
        return sum((comp for i, comp in enumerate(args) if (mask >> i) & 1), Comp.empty())

    return CompArray.make(joint(mask) for mask in range(1, 1 << count))
def ent_vector(*args):
    """Entropy vector: entropies of the joint variables produced by comp_vector(*args).

    Returns an empty ExprArray when called with no arguments.

    Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities,"
    IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998.
    """
    if not args:
        return ExprArray.empty()
    comps = comp_vector(*args)
    return H(comps)
def ent_vector_lex(*args):
    """Entropy vector in lexicographical (bitmask) order, using comp_vector_lex(*args).

    Returns an empty ExprArray when called with no arguments.

    Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities,"
    IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998.
    """
    if not args:
        return ExprArray.empty()
    comps = comp_vector_lex(*args)
    return H(comps)
def mi_vector(*args, minsize = 1, maxsize = 1000):
    """Mutual information vector.

    For each subset of args enumerated by igen.subset (with the given
    minsize / maxsize), appends the mutual information among the selected
    variables, I(X_i & X_j & ...).
    """
    count = len(args)
    masks = [1 << i for i in range(count)]
    result = ExprArray.empty()
    for sel in igen.subset(masks, minsize = minsize, maxsize = maxsize):
        chosen = (args[i] for i in range(count) if sel & (1 << i))
        result.append(I(alland(chosen)))
    return result
def ent_cells(*args, minsize = 1, maxsize = 1000):
    """Cells of the I-measure.

    For each subset S of args enumerated by igen.subset (with the given
    minsize / maxsize), appends I(&_{i in S} X_i | sum_{i not in S} X_i).

    Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities,"
    IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998.
    """
    count = len(args)
    result = ExprArray.empty()
    for sel in igen.subset([1 << i for i in range(count)], minsize = minsize, maxsize = maxsize):
        inside = alland(args[i] for i in range(count) if sel & (1 << i))
        # Everything not in the cell is conditioned on.
        outside = sum(args[i] for i in range(count) if not (sel & (1 << i)))
        result.append(I(inside | outside))
    return result
def mi_cells(*args, minsize = 2, maxsize = 1000):
    """Cells of the I-measure, excluding conditional entropies.

    Same as ent_cells, but with a default minsize of 2 so that single-variable
    cells (conditional entropies) are skipped.

    Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities,"
    IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998.
    """
    options = {"minsize": minsize, "maxsize": maxsize}
    return ent_cells(*args, **options)
def ent_region(n, real_name = "R", var_name = "X", name_subset = True, st = 0):
    """Entropy region on n random variables.

    Creates random variables via rv_seq(var_name, 0, n) and a real array named
    real_name, and returns the region where that array equals the entropy
    vector H(comp_vector(...)), with the random variables existentially
    quantified.

    If name_subset is True, the reals are indexed by concatenated subset labels
    (indices offset by st); otherwise by the integers 1 .. 2^n - 1.

    Z. Zhang and R. W. Yeung, "On characterization of entropy function via information inequalities,"
    IEEE Trans. Inform. Theory, vol. 44, pp. 1440-1452, Jul 1998.
    """
    xs = rv_seq(var_name, 0, n)
    cv = comp_vector(*xs)
    if not name_subset:
        real_ids = range(1, 1 << n)
    else:
        labels = [[str(i + st)] for i in range(n)]
        real_ids = ["".join(ids) for ids in igen.subset(labels, minsize = 1, zero_elem = [])]
    reals = real_array(real_name, real_ids)
    return (reals == H(cv)).exists(xs)
def csiszar_sum(*args, telescoping = False, seq_cond = False, timerv = None):
    """Csiszar sum identity.

    Arguments that are CompArray or list are treated as sequences; every other
    argument is a single random variable that may be included in the
    conditioning. Only sequences of length >= 3 can play the role of x; any
    sequence can play the role of y (its first two entries are used).

    Parameters
    ----------
    telescoping : if True, emit Ic(x[0]+x[2]; y[1] | u) == Ic(x[2]; y[0]+y[1] | u);
        otherwise emit Ic(x[2]; y[0] | y[1]+u) == Ic(y[1]; x[0] | x[2]+u).
    seq_cond : if True, sequences not used in x or y may also be added
        (whole) to the conditioning u.
    timerv : if given, adds H(timerv | a[i]) == 0 for i >= 1 for every sequence a.

    Returns
    -------
    Region: intersection of all generated identities, with pf_note "Csiszar sum".

    J. Korner and K. Marton, "Images of a set via two channels and their role in multi-user communication,"
    IEEE Trans. Inf. Theory, vol. 23, no. 6, pp. 751-761, 1977.
    I. Csiszar and J. Korner, "Broadcast channels with confidential messages,"
    IEEE Trans. Inf. Theory, vol. 24, no. 3, pp. 339-348, 1978.
    """
    r = Region.universe()
    vecs = []
    # bitmask over vecs of those with length >= 3 (usable as x)
    vecs_f_mask = 0
    ones = []
    for a in args:
        if isinstance(a, CompArray) or isinstance(a, list):
            vecs.append(a)
            if len(a) >= 3:
                vecs_f_mask |= 1 << (len(vecs) - 1)
            # ones.append(sum(a))
            if timerv is not None:
                for i in range(1, len(a)):
                    r &= H(timerv | a[i]) == 0
        else:
            ones.append(a)
    for xmask in igen.subset_mask(vecs_f_mask):
        if xmask == 0:
            continue
        x = sum(vecs[i] for i in range(len(vecs)) if xmask & (1 << i))
        #for ymask in igen.subset_mask((1 << len(vecs)) - 1 - xmask):
        for ymask in range(1, 1 << len(vecs)):
            # (ymask starts at 1, so this check never triggers)
            if ymask == 0:
                continue
            y = sum(vecs[i][0:2] for i in range(len(vecs)) if ymask & (1 << i))
            vmask_all = 0
            if seq_cond:
                # sequences used in neither x nor y
                vmask_all = ((1 << len(vecs)) - 1) & ~(xmask | ymask)
            for vmask in igen.subset_mask(vmask_all):
                for umask in range(1 << len(ones)):
                    # conditioning: a subset of the single variables plus a
                    # subset of the unused sequences (taken whole)
                    u = (sum((ones[i] for i in range(len(ones)) if umask & (1 << i)), Comp.empty())
                         + sum(sum(vecs[i]) for i in range(len(vecs)) if vmask & (1 << i)))
                    if telescoping:
                        r &= Expr.Ic(x[0] + x[2], y[1], u) == Expr.Ic(x[2], y[0] + y[1], u)
                    else:
                        r &= Expr.Ic(x[2], y[0], y[1]+u) == Expr.Ic(y[1], x[0], x[2]+u)
    return r.add_meta("pf_note", ["Csiszar sum"])
def ingleton_term(a1, a2, a3, a4):
    """The expression in Ingleton inequality.

    Returns
        H(a1a2) + H(a1a3) + H(a1a4) + H(a2a3) + H(a2a4)
        - H(a1) - H(a2) - H(a1a2a3) - H(a1a2a4) - H(a3a4),
    so that the Ingleton inequality reads ingleton_term(...) >= 0.

    A. W. Ingleton, "Representation of matroids," in Combinatorial mathematics
    and its applications, D. Welsh, Ed. London: Academic Press, pp. 149-167, 1971.
    """
    h = Expr.H
    inner = (h(a1) + h(a2) + h(a1+a2+a3) + h(a1+a2+a4) + h(a3+a4)
             - h(a1+a2) - h(a1+a3) - h(a1+a4) - h(a2+a3) - h(a2+a4))
    return -inner
def ingleton_ineq(a1, a2, a3, a4):
    """Ingleton inequality: the region ingleton_term(a1, a2, a3, a4) >= 0,
    tagged with pf_note "Ingleton ineq.".

    A. W. Ingleton, "Representation of matroids," in Combinatorial mathematics
    and its applications, D. Welsh, Ed. London: Academic Press, pp. 149-167, 1971.
    """
    term = ingleton_term(a1, a2, a3, a4)
    region = term >= 0
    return region.add_meta("pf_note", ["Ingleton ineq."])
def ingleton_bound(*args):
    """Bound on entropy obtained by Ingleton inequality.

    Intersects ingleton_ineq(a1, a2, a3, a4) over all choices of pairwise
    disjoint, nonempty joint variables a1..a4 formed from args. Choices with
    mask(a2) > mask(a1) or mask(a4) > mask(a3) are skipped by breaking out of
    the loop. This exploits the symmetry of the inequality in (a1, a2) and in
    (a3, a4), and presumably relies on igen.subset_mask yielding masks in
    increasing order.

    A. W. Ingleton, "Representation of matroids," in Combinatorial mathematics
    and its applications, D. Welsh, Ed. London: Academic Press, pp. 149-167, 1971.
    """
    n = len(args)
    full = (1 << n) - 1

    def join(mask):
        # Joint random variable of the arguments selected by the bitmask.
        return sum((args[i] for i in range(n) if mask & (1 << i)), Comp.empty())

    r = Region.universe()
    for m1 in igen.subset_mask(full):
        if not m1:
            continue
        c1 = join(m1)
        for m2 in igen.subset_mask(full - m1):
            if not m2:
                continue
            if m2 > m1:
                break
            c2 = join(m2)
            for m3 in igen.subset_mask(full - m1 - m2):
                if not m3:
                    continue
                c3 = join(m3)
                for m4 in igen.subset_mask(full - m1 - m2 - m3):
                    if not m4:
                        continue
                    if m4 > m3:
                        break
                    c4 = join(m4)
                    r &= ingleton_ineq(c1, c2, c3, c4)
    return r
def ci_ingleton():
    """ Conditional Ingleton inequalities

    Creates four disjoint, nonempty random variables A, B, C, D and returns the
    intersection, over five conditional-independence premises, of
    "premise implies ingleton_ineq(C, D, A, B)", universally quantified over A, B, C, D.

    Milan Studeny, "Conditional independence structures over four discrete random variables
    revisited: conditional Ingleton inequalities," IEEE Trans Info. Theory, 2021.
    """
    disjoint_id = iutil.get_count()
    A, B, C, D = [Comp.rv(name) for name in ("A", "B", "C", "D")]
    for comp in (A, B, C, D):
        comp.add_markers([("disjoint", disjoint_id), ("nonempty", 1)])
    ing = ingleton_ineq(C,D,A,B)
    premises = [
        ~I(A&B) & ~I(A&B|C),
        ~I(A&B|C) & ~I(B&D|C),
        ~I(A&C|D) & ~I(A&D|C),
        ~I(A&C|D) & ~I(C&D|A),
        ~I(A&C|D) & ~I(B&C|D),
    ]
    result = RegionOp.inter([])
    for premise in premises:
        result &= (premise >> ing).forall(A+B+C+D)
    return result
def dfz_linear_ineq(A, B, C, D, E):
    """Linear rank inequalities in:
    Dougherty, Randall, Chris Freiling, and Kenneth Zeger. "Linear rank inequalities
    on five or more variables." arXiv preprint arXiv:0910.0284 (2009).

    Returns the intersection of the 24 inequalities below, tagged with
    pf_note "DFZ linear ineq.".

    Notation: inside each I(...), Python precedence makes `+` (joint variable)
    bind tighter than `&` (mutual information), and `|` (conditioning) bind
    loosest, so e.g. I(A&B&D+E) is I(A;B;D,E) and I(A&C|D+E) is I(A;C|D,E).
    """
    return ( ( I(A&B&C) <= I(A&B|D)+I(C&D|E)+I(A&E) )
        &( I(A&B&C) <= I(A&C|D)+I(A&D|E)+I(B&E) )
        &( I(A&B&D) <= I(A&C)+I(B&E|C)+I(A&D|C+E) )
        &( I(A&B&D+E) <= I(A&C)+I(B&D|C)+I(A&E|C+D) )
        &( I(A&B&C+E) <= I(A&C)+I(B&D|C)+I(A&E|D)+I(B&C|D+E) )
        &( I(A&B&C+D) <= I(A&C)+I(B&D|E)+I(D&E|C)+I(A&C|D+E) )
        &( I(A&B&D+E) <= I(A&C|D)+I(A&E|C)+I(B&D)+I(B&D|C+E) )
        &( I(A&B&D)+I(A&B&C) <= I(A&B|E)+I(C&D)+I(C+D&E) )
        &( I(A&B&E)+I(A&B&D) <= I(A&C)+I(D&E)+I(B&D+E|C) )
        &( I(A&B&D)+I(A&B&C) <= I(C&D)+I(A&E)+I(B&D|E)+I(A&C|D+E) )
        &( I(A&B+C) <= I(A&C|B+D)+I(A&C+E)+I(A&B|D+E)+I(B&D|C+E) )
        &( I(A&B|C) <= I(A&B|D)+I(A&D|E)+I(B&E|C)+I(A&C|B+E)+I(C&E|B+D) )
        &( I(A&B+C) <= I(A&B|D)+I(A&C+E)+I(B&D|C+E)+I(A&C|B+E)+I(C&E|B+D) )
        &( I(A&B+C) <= I(A&D)+I(B&E|D)+I(A&B|C+E)+I(A&C|B+D)+I(A&C|D+E) )
        &( I(A&B+C) <= I(A&D)+I(B&E|D)+I(A&C|E)+I(A&B|C+D)+I(A&C|B+D)+I(B&D|C+E) )
        &( I(A&B+C) <= I(A&B|C+D)+I(A&C|B+D)+I(B+C&D|E)+I(B&C|D+E)+I(A&E) )
        &( I(A+B&C|D) <= I(A&D|B+C)+I(B&D|A+C)+I(A&C|B+E)+I(B&C|A+E)+I(A&B|D+E)+I(C&E|D) )
        &( I(A&B&D)+I(A&C&D) <= I(B&C)+I(B&D|E)+I(C&D|E)+I(A&E) )
        &( I(A&B&E)+I(A&C&D) <= I(B&D)+I(A&C|D)+I(D&E)+I(B&E|C+D)+I(C&D|B+E) )
        &( I(A&B&E)+I(A&C&D) <= I(B&C)+I(B&D)+I(A&E|B)+I(C&D|E)+I(B&E|C+D) )
        &( I(A&B&C+E)+I(A&C&D) <= I(B&D)+I(A&D|E)+I(C&E)+I(B&C|D+E)+I(B&E|C+D) )
        &( I(A&B&D)+I(A&B&C)+I(A&C&E) <= I(C&D)+I(A&D|E)+2*I(B&E)+I(B&C|D+E)+I(C&E|B+D) )
        &( I(A&B&D)+I(A&B+C) <= 2*I(A&C|E)+I(B&E)+I(D&E)+I(A&B|C+D)+2*I(B&D|C+E)+I(C&E|B+D) )
        &( I(A&C+D)+I(B&C|D) <= I(B&C|E)+I(C&E|D)+I(A&E)+I(A&C|B+D)+I(A+B&D|C)+I(A&D|B+E)+I(A&B|D+E) )
        ).add_meta("pf_note", ["DFZ linear ineq."])
def linear_bound(*args):
    """The outer bound of the region given by linear rank inequalities. Tight for 5 or fewer random variables.

    With 3 or fewer arguments this is the whole space. With 4 it is
    ingleton_bound(*args). With 5 or more, dfz_linear_ineq is additionally
    imposed for every ordered choice of five pairwise disjoint, nonempty joint
    variables formed from the arguments.

    Dougherty, Randall, Chris Freiling, and Kenneth Zeger. "Linear rank inequalities
    on five or more variables." arXiv preprint arXiv:0910.0284 (2009).
    """
    n = len(args)
    if n <= 3:
        return Region.universe()
    r = ingleton_bound(*args)
    if n <= 4:
        return r
    full = (1 << n) - 1

    def join(mask):
        # Joint random variable of the arguments selected by the bitmask.
        return sum((args[i] for i in range(n) if mask & (1 << i)), Comp.empty())

    for ma in igen.subset_mask(full):
        if not ma:
            continue
        A = join(ma)
        for mb in igen.subset_mask(full - ma):
            if not mb:
                continue
            B = join(mb)
            for mc in igen.subset_mask(full - ma - mb):
                if not mc:
                    continue
                C = join(mc)
                for md in igen.subset_mask(full - ma - mb - mc):
                    if not md:
                        continue
                    D = join(md)
                    for me in igen.subset_mask(full - ma - mb - mc - md):
                        if not me:
                            continue
                        E = join(me)
                        r &= dfz_linear_ineq(A, B, C, D, E)
    return r
# Numerical functions
@fcn_list_to_list
def elog(x, base = None):
    """Logarithm of x as a symbolic expression.

    With base None, the natural log is scaled by PsiOpts.settings["ent_coeff"];
    otherwise it is divided by log(base). Evaluation uses torch when x or the
    scaling factor is a torch value, and numpy otherwise.
    """
    loge = PsiOpts.settings["ent_coeff"]
    fcnargs = [x] if base is None else [x, base]
    sym = Expr.real(iutil.fcn_name_maker("log", fcnargs, pname = "elog", lname = "\\log"))
    reg = Region.universe()

    def fcncall(x, cbase = None):
        if cbase is None:
            coeff = loge
        elif iutil.istorch(cbase):
            coeff = 1.0 / torch.log(cbase)
        else:
            coeff = 1.0 / numpy.log(cbase)
        use_torch = iutil.istorch(x) or iutil.istorch(coeff)
        logged = torch.log(x) if use_torch else numpy.log(x)
        return logged * coeff

    return Expr.fromterm(Term(sym.terms[0][0].x, Comp.empty(), reg, 0, fcncall, fcnargs))
@fcn_list_to_list
def renyi(*args, order = 2):
    """Renyi entropy of the given order.

    Special cases (when the evaluated order is a float): order 1 gives the
    Shannon entropy, order 0 gives log of the number of outcomes with
    probability above eps, and order inf gives -log(max probability).
    Otherwise log(sum p^order) / (1 - order), scaled by ent_coeff.

    For a ConcDist argument the value is computed directly. Otherwise a
    symbolic term is returned, with the known bounds 0 <= R <= H(x) for
    order > 1 and R >= H(x) for order < 1.

    Renyi, Alfred (1961). "On measures of information and entropy". Proceedings of the
    fourth Berkeley Symposium on Mathematics, Statistics and Probability 1960. pp. 547-561.
    """
    ceps = PsiOpts.settings["eps"]
    loge = PsiOpts.settings["ent_coeff"]
    x = Term.from_symbols(args)
    if isinstance(order, int):
        order = float(order)

    def fcncall(xdist, corder):
        # Always use the order supplied at evaluation time (corder); the
        # enclosing `order` may be symbolic and must not be tested here.
        if isinstance(corder, ConcReal):
            corder = corder.x
        if isinstance(corder, int):
            corder = float(corder)
        if isinstance(corder, float):
            if abs(corder - 1) < ceps:
                return xdist.entropy()
            if abs(corder) < ceps:
                support = 0.0
                for a in xdist.items():
                    if a > ceps:
                        support += 1
                return numpy.log(support) * loge
            if numpy.isinf(corder):
                if xdist.istorch():
                    return -torch.log(torch.max(torch.flatten(xdist.p))) * loge
                return -numpy.log(max(xdist.items())) * loge
        # General order (also used when corder is not a plain float).
        if xdist.istorch():
            return (torch.log(torch.sum(torch.pow(torch.flatten(xdist.p), corder)))
                    / (1.0 - corder) * loge)
        return (numpy.log(sum(numpy.power(a, corder) for a in xdist.items()))
                / (1.0 - corder) * loge)

    if isinstance(x, ConcDist):
        return fcncall(x, order)
    R = Expr.real(iutil.fcn_name_maker("renyi", [x, order], pname = "renyi", cropi = True))
    reg = Region.universe()
    if isinstance(order, float):
        if abs(order - 1) < ceps:
            return H(x)
        if order > 1:
            reg = (R <= H(x)) & (R >= 0)
        else:
            reg = R >= H(x)
    return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), reg, 0, fcncall, [x, order]))
@fcn_list_to_list
def H0(*args):
    """Max entropy (Renyi entropy of order 0).

    Evaluates to log of the number of outcomes with probability above
    PsiOpts.settings["eps"], scaled by ent_coeff.

    Renyi, Alfred (1961). "On measures of information and entropy". Proceedings of the
    fourth Berkeley Symposium on Mathematics, Statistics and Probability 1960. pp. 547-561.
    """
    ceps = PsiOpts.settings["eps"]
    loge = PsiOpts.settings["ent_coeff"]
    x = Term.from_symbols(args)
    x.sort()

    def fcncall(xdist):
        support = sum((1.0 for a in xdist.items() if a > ceps), 0.0)
        return numpy.log(support) * loge

    if isinstance(x, ConcDist):
        return fcncall(x)
    sym = Expr.real(iutil.fcn_name_maker("H_0", x, pname = "H0", cropi = True))
    return Expr.fromterm(Term(sym.terms[0][0].x, Comp.empty(), None, 0, fcncall, [x], termtname = "H0"))
@fcn_list_to_list
def maxcorr(*args):
    """Maximal correlation (Hirschfeld-Gebelein-Renyi) between two random variables.

    For a joint distribution p(i, j) with marginals px, py, builds the matrix
    Q[i, j] = p(i, j) / sqrt(px[i] py[j]) - sqrt(px[i] py[j]) and returns its
    spectral norm (ord=2 matrix norm). In the numpy path, entries with
    sqrt(px py) <= eps are left at 0. In the torch path, opt_eps_denom is
    added to the denominator instead.

    Raises ValueError (at evaluation) if the flattened distribution is not
    over exactly two random variables.

    H. O. Hirschfeld, "A connection between correlation and contingency," in Mathematical
    Proceedings of the Cambridge Philosophical Society, vol. 31, no. 04. Cambridge Univ Press, 1935, pp. 520-524.
    H. Gebelein, "Das statistische problem der korrelation als variations-und eigenwertproblem
    und sein zusammenhang mit der ausgleichsrechnung," ZAMM-Journal of Applied Mathematics and
    Mechanics/Zeitschrift fur Angewandte Mathematik und Mechanik, vol. 21, no. 6, pp. 364-379, 1941.
    A. Renyi, "On measures of dependence," Acta mathematica hungarica, vol. 10, no. 3,
    pp. 441-451, 1959.
    """
    ceps = PsiOpts.settings["eps"]
    ceps_d = PsiOpts.settings["opt_eps_denom"]
    x = Term.from_symbols(args, prefer_multi=True)

    def fcncall(xdist):
        xdist = xdist.flattened_sublen()
        if len(xdist.p.shape) != 2:
            raise ValueError("Only maximal correlation between two random variables is supported.")
        nx, ny = xdist.p.shape
        if xdist.istorch():
            px = torch.sum(xdist.p, 1)
            py = torch.sum(xdist.p, 0)
            tmat = torch.zeros(xdist.p.shape, dtype=torch.float64)
            for i in range(nx):
                for j in range(ny):
                    rxy = torch.sqrt(px[i] * py[j])
                    # ceps_d keeps the division finite (and differentiable).
                    tmat[i, j] = xdist.p[i, j] / (rxy + ceps_d) - rxy
            return torch.linalg.norm(tmat, 2)
        px = numpy.sum(xdist.p, 1)
        py = numpy.sum(xdist.p, 0)
        tmat = numpy.zeros(xdist.p.shape)
        for i in range(nx):
            for j in range(ny):
                rxy = numpy.sqrt(px[i] * py[j])
                if rxy > ceps:
                    tmat[i, j] = xdist.p[i, j] / rxy - rxy
        return numpy.linalg.norm(tmat, 2)

    if isinstance(x, ConcDist):
        return fcncall(x)
    R = Expr.real(iutil.fcn_name_maker("maxcorr", x, pname = "maxcorr", cropi = True))
    reg = R >= 0
    return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), reg, 0, fcncall, [x]))
@fcn_list_to_list
def divergence(x, y, mode = "kl"):
    """
    Divergence between probability distributions.
    Parameters
    ----------
    x : Comp or ConcDist
        The first distribution (if random variable is given, consider its distribution).
    y : Comp or ConcDist
        The second distribution (if random variable is given, consider its distribution).
    mode : str, optional
        Choices are "kl" (Kullback-Leibler divergence or relative entropy),
        "tv" (total variation distance), "chi2" (Chi-squared divergence),
        "hellinger" (Hellinger distance) and "js" (Jensen-Shannon divergence).
        Case-insensitive. The default is "kl".
    Returns
    -------
    Expr, float or torch.Tensor
        The expression of the divergence. If x,y are ConcDist, gives float or torch.Tensor.
    Raises
    ------
    ValueError
        If mode is not one of the choices above (previously an unknown mode
        silently evaluated to 0).
    """
    mode = mode.lower()
    valid_modes = ("kl", "tv", "chi2", "hellinger", "js")
    if mode not in valid_modes:
        raise ValueError("Unknown divergence mode " + repr(mode)
                         + ". Choices are: " + ", ".join(valid_modes) + ".")
    ceps_d = PsiOpts.settings["opt_eps_denom"]
    loge = PsiOpts.settings["ent_coeff"]

    def fcncall(xdist, ydist):
        r = 0.0
        if mode == "kl":
            for a, b in zip(xdist.items(), ydist.items()):
                r += iutil.xlogxoy(a, b) * loge
        elif mode == "tv":
            for a, b in zip(xdist.items(), ydist.items()):
                r += abs(a - b) * 0.5
        elif mode == "chi2":
            if xdist.istorch() or ydist.istorch():
                # ceps_d keeps the torch computation finite/differentiable.
                for a, b in zip(xdist.items(), ydist.items()):
                    r += (a ** 2) / (b + ceps_d)
            else:
                for a, b in zip(xdist.items(), ydist.items()):
                    r += (a ** 2) / b
            r -= 1.0
        elif mode == "hellinger":
            for a, b in zip(xdist.items(), ydist.items()):
                r += (iutil.sqrt(a) - iutil.sqrt(b)) ** 2
            r = iutil.sqrt(r * 0.5)
        elif mode == "js":
            for a, b in zip(xdist.items(), ydist.items()):
                m = (a + b) * 0.5
                r += iutil.xlogxoy(a, m) * loge * 0.5
                r += iutil.xlogxoy(b, m) * loge * 0.5
        return r

    if isinstance(x, ConcDist) and isinstance(y, ConcDist):
        return fcncall(x, y)
    R = Expr.real(iutil.fcn_name_maker("divergence", [x, y, mode], pname = "divergence", cropi = True))
    reg = R >= 0
    return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), reg, 0, fcncall, [x, y]))
@fcn_list_to_list
def varent(*args):
    """Varentropy (variance of self information) and dispersion (variance of information density).

    For a single random variable, returns E[(log p)^2] - H^2. For several,
    returns the variance of log(p(x1..xn) / prod_i p(xi)) under p, i.e.
    sum p*log^2(.) - (sum p*log(.))^2 as computed via iutil.xlogxoy /
    iutil.xlogxoy2. NOTE(review): unlike entropy/divergence here, the result is
    not scaled by ent_coeff -- confirm the intended units.

    Kontoyiannis, Ioannis, and Sergio Verdu. "Optimal lossless compression: Source
    varentropy and dispersion." 2013 IEEE International Symposium on Information Theory. IEEE, 2013.
    Polyanskiy, Yury, H. Vincent Poor, and Sergio Verdu. "Channel coding rate in the finite
    blocklength regime." IEEE Transactions on Information Theory 56.5 (2010): 2307-2359.
    """
    x = Term.from_symbols(args)

    def fcncall(xdist):
        xdist = xdist.flattened_sublen()
        n = len(xdist.p.shape)
        # Marginal of each axis, obtained by summing out all other axes.
        if xdist.istorch():
            pxs = [torch.sum(xdist.p, tuple(j for j in range(n) if j != i)) for i in range(n)]
        else:
            pxs = [numpy.sum(xdist.p, tuple(j for j in range(n) if j != i)) for i in range(n)]
        s1 = 0.0
        s2 = 0.0
        for zs in itertools.product(*[range(z) for z in xdist.p.shape]):
            if n == 1:
                s1 += -iutil.xlogxoy(xdist.p[zs], 1.0)
                s2 += iutil.xlogxoy2(xdist.p[zs], 1.0)
            else:
                prod = iutil.product(pxs[i][zs[i]] for i in range(n))
                s1 += iutil.xlogxoy(xdist.p[zs], prod)
                s2 += iutil.xlogxoy2(xdist.p[zs], prod)
        return s2 - s1 ** 2

    if isinstance(x, ConcDist):
        return fcncall(x)
    R = Expr.real(iutil.fcn_name_maker("varent", x, pname = "varent", cropi = True))
    reg = R >= 0
    return Expr.fromterm(Term(R.terms[0][0].x, Comp.empty(), reg, 0, fcncall, [x]))
def series_sym(x, sufp = "P", suff = "F"):
    """Series symbols (X, X_P, X_F), representing present, past and future.

    Thin wrapper around CompArray.series_sym; sufp and suff are the name
    suffixes for the past and future symbols.
    """
    suffixes = {"sufp": sufp, "suff": suff}
    return CompArray.series_sym(x, **suffixes)
def main():
    """Interactive prompt for the PSITIP region parser.

    Reads lines from standard input and accumulates them (joined by spaces)
    into a buffer that is passed to RegionParser.parse. On success the result
    is printed and the buffer is cleared. On a lark parse error, the buffer is
    kept so that input can continue on the next line; an empty line prints the
    error and clears the buffer. Any other exception is printed and clears the
    buffer. "exit", "quit" or EOF ends the loop.
    """
    parser = RegionParser()
    # input accumulated so far that has not yet been parsed successfully
    s = ""
    PsiOpts.setting(str_style = "std")
    while True:
        try:
            inputstr = "> "
            if s != "":
                # continuation prompt while a multi-line input is pending
                inputstr = " "
            t = input(inputstr)
        except EOFError:
            break
        if t == "exit" or t == "quit":
            break
        if s != "":
            s += " "
            # s += "\n"
        s += t
        try:
            res = parser.parse(s)
            if res is None:
                pass
            elif isinstance(res, (Expr, Region)):
                print(res.simplified_truth(quick = True))
            elif isinstance(res, float):
                print(iutil.float_tostr(res, bracket = False))
            else:
                print(res)
            print()
            s = ""
        except lark.exceptions.LarkError as err:
            # Possibly incomplete input: keep accumulating unless the user
            # entered an empty line, which reports the error and resets.
            if t == "":
                print(err)
                print()
                s = ""
        except Exception as err:
            print(err)
            print()
            s = ""
# Start the interactive prompt when this file is run as a script.
if __name__ == '__main__':
    main()