from collections.abc import Callable
from dataclasses import dataclass
from typing import Literal, Optional

from z3 import (Context, DatatypeSortRef, ExprRef, FuncDeclRef, Solver,
                SortRef, sat)

from .intermediate_model.doml_element import IntermediateModel
from .mc_result import MCResult, MCResults
from .z3encoding.im_encoding import (assert_im_associations,
                                     assert_im_attributes,
                                     def_elem_class_f_and_assert_classes,
                                     mk_attr_data_sort, mk_elem_sort_dict,
                                     mk_stringsym_sort_dict)
from .z3encoding.metamodel_encoding import (def_association_rel,
                                            def_attribute_rel,
                                            mk_association_sort_dict,
                                            mk_attribute_sort_dict,
                                            mk_class_sort_dict)
from .z3encoding.types import Refs


@dataclass
class SMTEncoding:
    """Bundle of the z3 references produced while encoding a model.

    Instances are built by ``IntermediateModelChecker.instantiate_solver`` and
    handed to every requirement's ``assert_callable``.
    """

    # Refs returned by mk_class_sort_dict for the metamodel.
    classes: Refs
    # Refs returned by mk_association_sort_dict for the metamodel.
    associations: Refs
    # Refs returned by mk_attribute_sort_dict for the metamodel.
    attributes: Refs
    # Refs returned by mk_elem_sort_dict for the intermediate model.
    elements: Refs
    # Refs returned by mk_stringsym_sort_dict.
    str_symbols: Refs
    # Function returned by def_elem_class_f_and_assert_classes.
    element_class_fun: FuncDeclRef
    # Relation returned by def_attribute_rel.
    attribute_rel: FuncDeclRef
    # Relation returned by def_association_rel.
    association_rel: FuncDeclRef


@dataclass
class SMTSorts:
    """Bundle of the z3 sorts created alongside the ``SMTEncoding`` refs."""

    class_sort: SortRef
    association_sort: SortRef
    attribute_sort: SortRef
    element_sort: SortRef
    str_symbols_sort: SortRef
    # Sort returned by mk_attr_data_sort, built from the string-symbol sort.
    attr_data_sort: DatatypeSortRef


@dataclass
class Requirement:
    """A single property to be checked against the encoded model."""

    # Called with (SMTEncoding, SMTSorts); its result is asserted in the
    # solver and tracked under ``assert_name``.
    assert_callable: Callable[[SMTEncoding, SMTSorts], ExprRef]
    # Tracking name for the assertion; also the identifier matched by
    # RequirementStore.skip_requirements_by_id.
    assert_name: str
    description: str
    # (requirement type, function producing the error text). The function is
    # called with (solver, sorts, intermediate model) after the check.
    error_description: tuple[
        Literal["BUILTIN", "USER"],
        Callable[[Solver, SMTSorts, IntermediateModel], str],
    ]
    # Forwarded to MCResult.from_z3result as ``flipped``.
    flipped: bool = False


class RequirementStore:
    """Ordered collection of ``Requirement`` objects."""

    def __init__(self, requirements: Optional[list[Requirement]] = None):
        """Create a store.

        Args:
            requirements: Initial requirements. When omitted, a new empty
                list is created for this instance (a literal ``[]`` default
                would be shared by every store built without arguments).
        """
        self.requirements = [] if requirements is None else requirements

    def get_all_requirements(self) -> list[Requirement]:
        """Return the list of stored requirements."""
        return self.requirements

    def get_one_requirement(self, index: int) -> Requirement:
        """Return the requirement at position ``index``."""
        return self.get_all_requirements()[index]

    def skip_requirements_by_id(self, requirement_ids: list[str]):
        """Drop every requirement whose ``assert_name`` is in ``requirement_ids``.

        The stored list is replaced by a new filtered list rather than
        being modified in place.
        """
        self.requirements = [
            r for r in self.requirements
            if r.assert_name not in requirement_ids
        ]

    def __len__(self):
        return len(self.get_all_requirements())

    def __add__(self, other: "RequirementStore") -> "RequirementStore":
        """Return a new store holding this store's requirements then ``other``'s."""
        return RequirementStore(self.requirements + other.requirements)


class IntermediateModelChecker:
    """Encodes an intermediate model into a z3 solver and checks requirements."""

    def __init__(self, metamodel, inv_assoc,
                 intermediate_model: IntermediateModel):
        """Store the inputs and build the solver with the model encoding.

        Args:
            metamodel: Metamodel passed to the ``mk_*_sort_dict`` encoders.
            inv_assoc: Stored on the instance; not used by the methods
                defined in this class.
            intermediate_model: The model to encode and check.
        """
        self.metamodel = metamodel
        self.inv_assoc = inv_assoc
        self.intermediate_model = intermediate_model
        self.instantiate_solver()

    def instantiate_solver(self, user_string_values: Optional[list] = None):
        """(Re)create the z3 context and solver and assert the model encoding.

        Sets ``self.z3Context``, ``self.solver``, ``self.smt_encoding`` and
        ``self.smt_sorts``. Any previously created solver is discarded.

        Args:
            user_string_values: Extra values forwarded to
                ``mk_stringsym_sort_dict``; an empty list is used when
                omitted. Presumably user-supplied string constants — confirm
                against ``mk_stringsym_sort_dict``.
        """
        if user_string_values is None:
            user_string_values = []

        self.z3Context = Context()
        self.solver = Solver(ctx=self.z3Context)

        # Sorts and refs derived from the metamodel.
        class_sort, class_ = mk_class_sort_dict(self.metamodel, self.z3Context)
        assoc_sort, assoc = mk_association_sort_dict(
            self.metamodel, self.z3Context)
        attr_sort, attr = mk_attribute_sort_dict(
            self.metamodel, self.z3Context)

        # Sorts and refs derived from the intermediate model.
        elem_sort, elem = mk_elem_sort_dict(
            self.intermediate_model, self.z3Context)
        str_sort, str_symbols = mk_stringsym_sort_dict(
            self.intermediate_model,
            self.metamodel,
            self.z3Context,
            user_string_values,
        )
        attr_data_sort = mk_attr_data_sort(str_sort, self.z3Context)

        elem_class_f = def_elem_class_f_and_assert_classes(
            self.intermediate_model,
            self.solver,
            elem_sort,
            elem,
            class_sort,
            class_,
        )

        attr_rel = def_attribute_rel(attr_sort, elem_sort, attr_data_sort)
        assert_im_attributes(
            attr_rel,
            self.solver,
            self.intermediate_model,
            self.metamodel,
            elem,
            attr_sort,
            attr,
            attr_data_sort,
            str_symbols,
        )

        assoc_rel = def_association_rel(assoc_sort, elem_sort)
        assert_im_associations(
            assoc_rel,
            self.solver,
            # Shallow copy of the model's items as a plain dict.
            dict(self.intermediate_model.items()),
            elem,
            assoc_sort,
            assoc,
        )

        self.smt_encoding = SMTEncoding(
            class_,
            assoc,
            attr,
            elem,
            str_symbols,
            elem_class_f,
            attr_rel,
            assoc_rel,
        )
        self.smt_sorts = SMTSorts(
            class_sort,
            assoc_sort,
            attr_sort,
            elem_sort,
            str_sort,
            attr_data_sort,
        )

    def check_requirements(self, reqs: RequirementStore,
                           timeout: int = 0) -> MCResults:
        """Check each requirement in its own solver scope.

        Args:
            reqs: Requirements to check, in store order.
            timeout: Multiplied by 1000 and set as the solver's ``timeout``
                parameter (z3 expects milliseconds, so this is in seconds).

        Returns:
            An ``MCResults`` built from one ``(result, type, error text)``
            tuple per requirement.
        """
        self.solver.set(timeout=(timeout * 1000))

        results = []
        for req in reqs.get_all_requirements():
            self.solver.push()
            try:
                self.solver.assert_and_track(
                    req.assert_callable(self.smt_encoding, self.smt_sorts),
                    req.assert_name,
                )
                res = self.solver.check()
                req_type, req_err_desc_fn = req.error_description
                req_is_sat = MCResult.from_z3result(res, flipped=req.flipped)
                # The error text is produced while the requirement is still
                # asserted, i.e. before the scope is popped.
                # NOTE(review): this relies on the truthiness of the value
                # returned by MCResult.from_z3result — confirm it is falsy
                # for the outcomes that should yield an empty description.
                results.append((
                    req_is_sat,
                    req_type,
                    req_err_desc_fn(
                        self.solver, self.smt_sorts, self.intermediate_model
                    ) if req_is_sat else "",
                ))
            finally:
                # Always drop the requirement's scope so an exception in one
                # requirement cannot leave its assertion on the solver.
                self.solver.pop()

        return MCResults(results)