Skip to content
Snippets Groups Projects
Select Git revision
  • 6b31ef775b0a012701d5d889f33548b7abbc87da
  • main default
  • y3
  • y1
4 results

utils.py

Blame
  • utils.py 4.11 KiB
    from typing import cast
    from collections.abc import Sequence
    from itertools import product
    
    from z3 import (
        BoolSort,
        BoolVal,
        Context,
        Datatype,
        DatatypeSortRef,
        ExprRef,
        EnumSort,
        FuncDeclRef,
        IntSort,
        Solver,
    )
    
    from .types import Refs, SortAndRefs
    
    
    def mk_enum_sort_dict(name: str, values: list[str], z3ctx: Context) -> SortAndRefs:
        """
        Declare a Z3 enumeration sort and index its values by name.

        ### Parameters
         - `name` is the name given to the new sort;
         - `values` are the names of the values of the sort;
         - `z3ctx` is the Z3 context passed to `EnumSort`.

        ### Returns
        A pair made of the sort and a dict mapping each name in `values` to the
        reference found at the same position in the result of `EnumSort`.
        """

        sort, value_refs = EnumSort(name, values, ctx=z3ctx)
        # Names and references are paired positionally.
        refs_by_name = {value: ref for value, ref in zip(values, value_refs)}
        return sort, refs_by_name
    
    
    def assert_relation_tuples(
        rel: FuncDeclRef,
        solver: Solver,
        rel_tpls: list[list[str]],
        *sig_dicts: Refs,
    ) -> None:
        """
        Assert on `solver` the complete truth table of the relation `rel`.

        Every combination of symbol names drawn from the Cartesian product of
        `sig_dicts` is asserted to be related (true) when it is listed in
        `rel_tpls`, and unrelated (false) otherwise. Entries of `rel_tpls` that
        name symbols which are not keys of the corresponding domain never occur
        in that product and are therefore not asserted.

        ### Parameters
         - `rel` is a Z3 function with return type Bool representing a relation;
         - `solver` is the Z3 solver on which to assert;
         - `rel_tpls` is a list of lists (tuples are accepted as well) of names
           of symbols which are related in `rel`;
         - `sig_dicts` are the domains of `rel`, each mapping symbol names to
           their Z3 references.

        ### Effects
        This procedure is effectful on `solver`.

        ### Raises
        `AssertionError` if the tuples in `rel_tpls` do not all have the same
        length, or if that length differs from the number of given domains.
        """

        # Length of tuples must be homogeneous and equal to the number of given
        # domains.
        if lengths := [len(tpl) for tpl in rel_tpls]:
            assert min(lengths) == max(lengths), (
                "tuples in rel_tpls must all have the same length"
            )
            assert lengths[0] == len(sig_dicts), (
                "length of tuples in rel_tpls must equal the number of domains"
            )

        # Hash the related tuples once, so that each membership test below is a
        # set lookup instead of a scan of the whole of `rel_tpls` for every
        # element of the Cartesian product. Converting to tuples also makes
        # the test independent of whether the caller used lists or tuples.
        related = {tuple(tpl) for tpl in rel_tpls}
        ctx = solver.ctx

        sym_tpls = [
            cast(
                list[ExprRef],
                [dom[sym_name] for sym_name, dom in zip(names, sig_dicts)],
            )
            # The last element is the truth value of `rel` on this combination.
            + [BoolVal(names in related, ctx=ctx)]
            for names in product(*sig_dicts)
        ]

        assert_function_tuples_raw(rel, solver, sym_tpls)
    
    
    def assert_function_tuples(
        f: FuncDeclRef,
        solver: Solver,
        f_tpls: list[list[str]],
        *sig_dicts: Refs,
    ) -> None:
        """
        Assert on `solver` the values of `f` listed, by symbol name, in `f_tpls`.

        ### Parameters
         - `f` is a Z3 function;
         - `solver` is the Z3 solver on which to assert;
         - `f_tpls` is a list of tuples of names of symbols. The first names in
           each tuple denote the inputs of `f`, and the last one its output;
         - `sig_dicts` are the domains of `f`, the last one being its codomain.
           Each of them maps symbol names to their Z3 references.

        ### Effects
        This procedure is effectful on `solver`.
        """

        # All tuples must share one length, and that length must match the
        # number of given domains.
        tpl_lengths = [len(f_tpl) for f_tpl in f_tpls]
        if tpl_lengths:
            assert min(tpl_lengths) == max(tpl_lengths)
            assert tpl_lengths[0] == len(sig_dicts)

        def to_refs(f_tpl: list[str]) -> list:
            # Look each name up in the domain found at the same position.
            return [dom[sym_name] for sym_name, dom in zip(f_tpl, sig_dicts)]

        ref_tpls = [to_refs(f_tpl) for f_tpl in f_tpls]

        assert_function_tuples_raw(f, solver, ref_tpls)
    
    
    def assert_function_tuples_raw(
        f: FuncDeclRef,
        solver: Solver,
        f_tpls: Sequence[Sequence[ExprRef]],
    ) -> None:
        """
        Assert on `solver`, for each tuple in `f_tpls`, that `f` applied to the
        leading elements equals the last element.

        Each assertion is passed to `solver.assert_and_track` together with a
        label made of the name of `f`, its arguments and the expected output,
        separated by spaces.

        ### Parameters
         - `f` is a Z3 function;
         - `solver` is the Z3 solver on which to assert;
         - `f_tpls` is a list of tuples of Z3 references or otherwise accepted
           values. The first elements in each tuple are the inputs of `f`, and
           the last element is its output.

        ### Effects
        This procedure is effectful on `solver`.
        """
        # Every tuple must have the same length.
        tpl_lengths = [len(tpl) for tpl in f_tpls]
        if tpl_lengths:
            assert min(tpl_lengths) == max(tpl_lengths)

        for tpl in f_tpls:
            # Split the tuple into the inputs of `f` and its expected output.
            *args, result = tpl
            constraint = f(*args) == result
            prefix = f"{f.name()} "
            args_text = " ".join(str(arg) for arg in args)
            label = prefix + args_text + f" {result}"
            solver.assert_and_track(constraint, label)
    
    
    def mk_stringsym_sort_from_strings(
        strings: list[str],
        z3ctx: Context
    ) -> SortAndRefs:
        """
        Declare the `StringSym` enumeration sort, with one value per given string.

        ### Parameters
         - `strings` are the strings to be represented by values of the sort;
         - `z3ctx` is the Z3 context forwarded to `mk_enum_sort_dict`.

        ### Returns
        A pair made of the sort and a dict mapping each string to the reference
        of the value declared for it. If a string occurs more than once in
        `strings`, the dict keeps the reference declared for its last
        occurrence.
        """

        def value_name(index: int, text: str) -> str:
            # Only the first 16 characters of the string contribute to the
            # name: alphanumeric ones are lowercased, any other is replaced by
            # an underscore. The index prefix keeps the names distinct.
            sanitized = "".join(
                char.lower() if char.isalnum() else "_" for char in text[:16]
            )
            return f"ss_{index}_{sanitized}"

        value_names = [value_name(index, text) for index, text in enumerate(strings)]
        sort, refs_by_value_name = mk_enum_sort_dict(
            "StringSym", value_names, z3ctx=z3ctx
        )
        refs_by_string = {
            text: refs_by_value_name[name] for text, name in zip(strings, value_names)
        }
        return sort, refs_by_string
    
    
    def mk_adata_sort(
        ss_sort: DatatypeSortRef,
        z3ctx: Context
    ) -> DatatypeSortRef:
        """
        Create the `AttributeData` datatype sort.

        The datatype has three constructors, each with a single accessor:
         - `int`, with accessor `get_int` of sort `IntSort`;
         - `bool`, with accessor `get_bool` of sort `BoolSort`;
         - `ss`, with accessor `get_ss` of sort `ss_sort`.

        ### Parameters
         - `ss_sort` is the sort of the field of the `ss` constructor;
         - `z3ctx` is the Z3 context used for the datatype and for the integer
           and Boolean sorts.
        """
        builder = Datatype("AttributeData", ctx=z3ctx)
        # (constructor name, accessor name, accessor sort), in declaration order.
        constructors = (
            ("int", "get_int", IntSort(ctx=z3ctx)),
            ("bool", "get_bool", BoolSort(ctx=z3ctx)),
            ("ss", "get_ss", ss_sort),
        )
        for constructor, accessor, field_sort in constructors:
            builder.declare(constructor, (accessor, field_sort))
        return builder.create()
    
    
    def Iff(a, b):
        """
        Return the result of `a == b`.

        NOTE(review): presumably meant for Z3 Boolean expressions, for which
        `==` is expected to build an equality term (i.e. logical equivalence)
        rather than a Python `bool` -- confirm against the callers.
        """
        equivalence = a == b
        return equivalence