From 074ed85d0524e1dae7f0786d1d785be9b8824157 Mon Sep 17 00:00:00 2001
From: Andrea Franchini <hello@andreafranchini.com>
Date: Mon, 30 Jan 2023 20:10:23 +0100
Subject: [PATCH] wip: parser for synthesis, todo: fix strings

---
 mc_openapi/__main__.py                       |  18 +-
 mc_openapi/doml_mc/domlr_parser/grammar.lark |   2 +-
 mc_openapi/doml_mc/domlr_parser/parser.py    | 297 ++++++++++++++++---
 mc_openapi/doml_mc/domlr_parser/utils.py     |  53 +++-
 tests/domlr/example_multiple_reqs.domlr      |   9 +
 5 files changed, 335 insertions(+), 44 deletions(-)

diff --git a/mc_openapi/__main__.py b/mc_openapi/__main__.py
index b83be2b..4aa03e4 100644
--- a/mc_openapi/__main__.py
+++ b/mc_openapi/__main__.py
@@ -4,7 +4,7 @@ import argparse
 from mc_openapi.app_config import app
 from mc_openapi.doml_mc import DOMLVersion
 from mc_openapi.doml_mc.domlr_parser.exceptions import RequirementException
-from mc_openapi.doml_mc.domlr_parser.parser import Parser
+from mc_openapi.doml_mc.domlr_parser.parser import DOMLRTransformer, Parser, SynthesisDOMLRTransformer
 from mc_openapi.doml_mc.imc import RequirementStore
 from mc_openapi.doml_mc.intermediate_model.metamodel import MetaModelDocs
 from mc_openapi.doml_mc.mc import ModelChecker
@@ -68,7 +68,7 @@ else:
             user_reqs = reqsf.read()
         # Parse them
         try:
-            domlr_parser = Parser()
+            domlr_parser = Parser(DOMLRTransformer)
             user_req_store, user_req_str_consts = domlr_parser.parse(user_reqs)
         except Exception as e:
             print(e)
@@ -117,13 +117,23 @@ else:
             for k, v in  dmc.intermediate_model.items()
         }
 
-        # TODO: Fetch (user_reqs, user_reqs_strings)
+        # Parse
+        synth_domlr_parser = Parser(SynthesisDOMLRTransformer)
+        synth_user_reqs, user_reqs_strings = synth_domlr_parser.parse(user_reqs)
 
+        print(user_reqs_strings)
+        
         state = State()
         # Parse MM and IM
         state = init_data(state, doml=im, metamodel=mm)
         # Solve
-        state = solve(state, requirements=[builtin_requirements], strings=[], max_tries=args.tries)
+
+        state = solve(
+            state, 
+            requirements=[builtin_requirements, synth_user_reqs], 
+            strings=user_reqs_strings, 
+            max_tries=args.tries
+        )
         # Update state
         state = save_results(state)
         # Print output
diff --git a/mc_openapi/doml_mc/domlr_parser/grammar.lark b/mc_openapi/doml_mc/domlr_parser/grammar.lark
index 2cb5539..1e7709a 100644
--- a/mc_openapi/doml_mc/domlr_parser/grammar.lark
+++ b/mc_openapi/doml_mc/domlr_parser/grammar.lark
@@ -23,7 +23,7 @@ error_desc      : ESCAPED_STRING
                 | "(" expression ")"
                 | property
 
-?property       : CONST "has" RELATIONSHIP CONST                            -> rel_elem_expr
+?property       : CONST "has" RELATIONSHIP CONST                            -> rel_assoc_expr
                 | CONST "has" RELATIONSHIP COMPARISON_OP value              -> rel_attr_value_expr
                 | CONST "has" RELATIONSHIP COMPARISON_OP CONST RELATIONSHIP -> rel_attr_elem_expr
                 | const_or_class "is" const_or_class                        -> equality
diff --git a/mc_openapi/doml_mc/domlr_parser/parser.py b/mc_openapi/doml_mc/domlr_parser/parser.py
index ef34a54..fd46277 100644
--- a/mc_openapi/doml_mc/domlr_parser/parser.py
+++ b/mc_openapi/doml_mc/domlr_parser/parser.py
@@ -5,14 +5,14 @@ from typing import Callable
 import yaml
 from lark import Lark, Transformer, UnexpectedCharacters
 from mc_openapi.doml_mc.domlr_parser.exceptions import RequirementBadSyntaxException
-from mc_openapi.doml_mc.domlr_parser.utils import (RefHandler, StringValuesCache,
+from mc_openapi.doml_mc.domlr_parser.utils import (RefHandler, StringValuesCache, SynthesisRefHandler,
                                                  VarStore)
 from mc_openapi.doml_mc.error_desc_helper import get_user_friendly_name
 from mc_openapi.doml_mc.imc import (Requirement, RequirementStore, SMTEncoding,
                                     SMTSorts)
 from mc_openapi.doml_mc.intermediate_model import IntermediateModel
 from z3 import And, Exists, ExprRef, ForAll, Implies, Not, Or, Solver, Xor, simplify
-
+from doml_synthesis import State
 
 class ParserData:
     def __init__(self) -> None:
@@ -27,8 +27,9 @@ class ParserData:
 PARSER_DATA = ParserData()
 
 class Parser:
-    def __init__(self, grammar: str = PARSER_DATA.grammar):
+    def __init__(self, transformer, grammar: str = PARSER_DATA.grammar):
         self.parser = Lark(grammar, start="requirements")
+        self.transformer = transformer
 
     def parse(self, input: str):
         """Parse the input string containing the DOMLR requirements and
@@ -42,9 +43,26 @@ class Parser:
             const_store = VarStore()
             user_values_cache = StringValuesCache()
 
-            transformer = DOMLRTransformer(const_store, user_values_cache)
+            transformer = self.transformer(const_store, user_values_cache)
+
+            if isinstance(self.transformer, DOMLRTransformer):
+                return (
+                    RequirementStore(transformer.transform(self.tree)), 
+                    user_values_cache.get_list()
+                )
+            else:
+                reqs = transformer.transform(self.tree)
+
+                # This function has to return state or it will break the
+                # synthesis solver
+                def user_reqs(state: State):
+                    for (req, id, negated) in reqs:
+                        state.solver.assert_and_track(
+                            req(state) if not negated else Not(req(state)), f'Requirement {id}')
+                    return state
+
+                return user_reqs, user_values_cache.get_list()
 
-            return RequirementStore(transformer.transform(self.tree)), user_values_cache.get_list()
         except UnexpectedCharacters as e:
             ctx = e.get_context(input)
             msg = _get_error_desc_for_unexpected_characters(e, input)
@@ -127,38 +145,7 @@ class DOMLRTransformer(Transformer):
     def forall(self, args):
         return lambda enc, sorts: ForAll(args[0](enc, sorts), args[1](enc, sorts))
 
-    # def relationship_expr(self, args):
-    #     print(args)
-    #     rel_name = args[1].value
-
-    #     def _gen_rel_expr(enc: SMTEncoding, sorts: SMTSorts):
-    #         rel, rel_type = RefHandler.get_relationship(enc, rel_name)
-            
-    #         if rel_type == RefHandler.ASSOCIATION:
-    #             self.const_store.use(args[0].value)
-    #             self.const_store.use(args[2].value)
-
-    #             return RefHandler.get_association_rel(
-    #                 enc,
-    #                 RefHandler.get_const(args[0].value, sorts),
-    #                 rel,
-    #                 RefHandler.get_const(args[2].value, sorts)
-    #             )
-    #         elif rel_type == RefHandler.ATTRIBUTE:
-    #             self.const_store.use(args[0].value)
-
-    #             return RefHandler.get_attribute_rel(
-    #                 enc,
-    #                 RefHandler.get_const(args[0].value, sorts),
-    #                 rel,
-    #                 args[2](enc, sorts)
-    #             )
-    #         else:
-    #             raise f"Error parsing relationship {rel_name}"
-        
-    #     return _gen_rel_expr
-
-    def rel_elem_expr(self, args):
+    def rel_assoc_expr(self, args):
         """An ASSOCIATION relationship"""
         rel_name = args[1].value
         self.const_store.use(args[0].value)
@@ -337,6 +324,7 @@ class DOMLRTransformer(Transformer):
         value = args[0].value
 
         if type == "ESCAPED_STRING":
+            value = value.replace('"', '')
             self.user_values_cache.add(value)
             return lambda enc, sorts: RefHandler.get_str(value, enc, sorts), RefHandler.STRING
         elif type == "NUMBER":
@@ -377,6 +365,241 @@ class DOMLRTransformer(Transformer):
             return msg + ("\n\n\tNOTES:" + notes if notes else "")
         return err_callback
 
+class SynthesisDOMLRTransformer(Transformer):
+    # These callbacks will be called when a rule with the same name
+    # is matched. It starts from the leaves.
+    def __init__(self, 
+        const_store: VarStore, 
+        user_values_cache: StringValuesCache,
+        visit_tokens: bool = True
+    ) -> None:
+        super().__init__(visit_tokens)
+        self.const_store = const_store
+        self.user_values_cache = user_values_cache
+
+    def __default__(self, data, children, meta):
+        return children
+
+    # start
+    def requirements(self, args) -> list[tuple]:
+        # TODO: Transform Requirement into 
+        return args
+
+    def requirement(self, args) -> tuple:
+        flip_expr: bool = args[0].value == "-"
+        name: str = args[1]
+        expr: Callable[[State], ExprRef] = args[2]
+        return (
+            expr,
+            name.lower().replace(" ", "_"), # id
+            flip_expr
+        )
+
+    def req_name(self, args) -> str:
+        return str(args[0].value.replace('"', ''))
+
+    # Requirement requirement expression
+
+    def bound_consts(self, args):
+        const_names = list(map(lambda arg: arg.value, args))
+        for name in const_names:
+            self.const_store.use(name)
+            self.const_store.quantify(name)
+        return lambda state: SynthesisRefHandler.get_consts(const_names, state)
+
+    def negation(self, args):
+        return lambda state: Not(args[0](state))
+
+    def iff_expr(self, args):
+        return lambda state: args[0](state) == args[1](state)
+    
+    def implies_expr(self, args):
+        return lambda state: Implies(args[0](state), args[1](state))
+
+    def and_expr(self, args):
+        return lambda state: And([arg(state) for arg in args])
+
+    def or_expr(self, args):
+        return lambda state: Or([arg(state) for arg in args])
+
+    def exists(self, args):
+        return lambda state: Exists(args[0](state), args[1](state))
+
+    def forall(self, args):
+        return lambda state: ForAll(args[0](state), args[1](state))
+
+    def rel_assoc_expr(self, args):
+        """An ASSOCIATION relationship"""
+        rel_name = args[1].value
+        self.const_store.use(args[0].value)
+        self.const_store.use(args[2].value)
+
+        def _gen_rel_elem_expr(state: State):
+            rel = SynthesisRefHandler.get_assoc(state, rel_name)
+
+            return state.rels.AssocRel(
+                SynthesisRefHandler.get_const(args[0].value, state),
+                rel.ref,
+                SynthesisRefHandler.get_const(args[2].value, state)
+            )
+        return _gen_rel_elem_expr
+
+    def rel_attr_value_expr(self, args):
+        """An ATTRIBUTE relationship with a rhs that is a value
+        
+           CONST "has" RELATIONSHIP COMPARISON_OP value
+           0           1            2             3
+        """
+
+        rel_name = args[1].value
+        def _gen_rel_attr_value_expr(state: State):
+            elem = SynthesisRefHandler.get_const(args[0].value, state)
+            rel = SynthesisRefHandler.get_attr(state, rel_name)
+
+            rhs_value, rhs_value_type = args[3]
+            rhs_value = rhs_value(state)
+            op = args[2].value
+
+            if rhs_value_type == SynthesisRefHandler.INTEGER and rel.type == 'Integer':
+                lhs_value = state.rels.int.AttrValueRel(elem, rel.ref)
+                return And(
+                    self.compare(op, lhs_value, rhs_value),
+                    state.rels.int.AttrSynthRel(elem, rel.ref)
+                )
+            elif op != "==" and op != "!=":
+                raise "You can only use == and != to compare Strings and Booleans!"
+            elif rhs_value_type == SynthesisRefHandler.STRING:
+                lhs_value = state.rels.str.AttrValueRel(elem, rel.ref) 
+                
+                return And(
+                    lhs_value == rhs_value if op == "==" else lhs_value != rhs_value,
+                    state.rels.str.AttrSynthRel(elem, rel.ref)
+                )
+            elif rhs_value_type == SynthesisRefHandler.BOOLEAN:
+                lhs_value = state.rels.bool.AttrValueRel(elem, rel.ref)  
+                return And(
+                    lhs_value == rhs_value if op == "==" else lhs_value != rhs_value,
+                    state.rels.bool.AttrSynthRel(elem, rel.ref)
+                )
+            else:
+                raise f'Invalid value {rhs_value} during parsing for synthesis.'
+            
+
+        return _gen_rel_attr_value_expr
+
+    def rel_attr_elem_expr(self, args):
+        """An ATTRIBUTE relationship with a rhs that is another element
+           CONST "has" RELATIONSHIP COMPARISON_OP CONST RELATIONSHIP
+           0           1            2             3     4
+        """
+
+        rel1_name = args[1].value
+        rel2_name = args[4].value
+        op = args[2].value
+
+        def _gen_rel_attr_elem_expr(state: State):
+            elem1 = SynthesisRefHandler.get_const(args[0].value, state)
+            elem2 = SynthesisRefHandler.get_const(args[3].value, state)
+            rel1 = SynthesisRefHandler.get_attr(state, rel1_name)
+            rel2 = SynthesisRefHandler.get_attr(state, rel2_name)
+
+            if rel1.type == rel2.type == 'Integer':
+                return And(
+                    state.rels.int.AttrSynthRel(elem1, rel1.ref),
+                    state.rels.int.AttrSynthRel(elem2, rel2.ref),
+                    self.compare(
+                        op, 
+                        state.rels.int.AttrValueRel(elem1, rel1.ref), 
+                        state.rels.int.AttrValueRel(elem2, rel2.ref)
+                    )
+                )
+            if rel1.type == rel2.type == 'Boolean':
+                return And(
+                    state.rels.bool.AttrSynthRel(elem1, rel1.ref),
+                    state.rels.bool.AttrSynthRel(elem2, rel2.ref),
+                    self.compare(
+                        op, 
+                        state.rels.bool.AttrValueRel(elem1, rel1.ref), 
+                        state.rels.bool.AttrValueRel(elem2, rel2.ref)
+                    )
+                )
+            if rel1.type == rel2.type == 'String':
+                return And(
+                    state.rels.str.AttrSynthRel(elem1, rel1.ref),
+                    state.rels.str.AttrSynthRel(elem2, rel2.ref),
+                    self.compare(
+                        op, 
+                        state.rels.str.AttrValueRel(elem1, rel1.ref), 
+                        state.rels.str.AttrValueRel(elem2, rel2.ref)
+                    )
+                )
+            raise f'Attribute relationships {rel1_name} ({rel1.type}) and {rel2_name} ({rel2.type}) have mismatch type.'
+
+        return _gen_rel_attr_elem_expr
+
+    def _get_equality_sides(self, arg1, arg2):
+        # We track use of const in const_or_class
+        if arg1.type == "CONST" and arg2.type == "CONST":
+            return (
+                lambda state: SynthesisRefHandler.get_const(arg1.value, state),
+                lambda state: SynthesisRefHandler.get_const(arg2.value, state)
+            )
+
+        if arg1.type == "CLASS":
+            arg1_ret = lambda state: SynthesisRefHandler.get_class(state, arg1.value)
+        else:
+            arg1_ret = lambda state: SynthesisRefHandler.get_element_class(state, SynthesisRefHandler.get_const(arg1.value, state))
+
+        if arg2.type == "CLASS":
+            arg2_ret = lambda state: SynthesisRefHandler.get_class(state, arg2.value)
+        else:
+            arg2_ret = lambda state: SynthesisRefHandler.get_element_class(state, SynthesisRefHandler.get_const(arg2.value, state))
+
+        return (arg1_ret, arg2_ret)
+
+    def equality(self, args):
+        a, b = self._get_equality_sides(args[0], args[1])
+        return lambda state: a(state) == b(state)
+
+    def inequality(self, args):
+        a, b = self._get_equality_sides(args[0], args[1])
+        return lambda state: a(state) != b(state)
+
+    def const_or_class(self, args):
+        if args[0].type == "CONST":
+            self.const_store.use(args[0].value)
+        return args[0]
+    
+    def compare(self, op: str, a, b):
+
+        if op == ">":
+            return a > b
+        if op == ">=":
+            return a >= b
+        if op == "<":
+            return a < b
+        if op == "<=":
+            return a <= b
+        if op == "==":
+            return a == b
+        if op == "!=":
+            return a != b
+        raise f"Invalid Compare Operator Symbol: {op}"
+
+
+    def value(self, args):  
+        type = args[0].type
+        value = args[0].value
+
+        if type == "ESCAPED_STRING":
+            value = value.replace('"', '')
+            self.user_values_cache.add(value)
+            return lambda state: SynthesisRefHandler.get_str(value, state), SynthesisRefHandler.STRING
+        elif type == "NUMBER":
+            return lambda _: value, SynthesisRefHandler.INTEGER
+        elif type == "BOOL":
+            return lambda _: SynthesisRefHandler.get_bool(value), SynthesisRefHandler.BOOLEAN 
+
 def _get_error_desc_for_unexpected_characters(e: UnexpectedCharacters, input: str):
     # Error description
     msg = "Syntax Error:\n\n"
diff --git a/mc_openapi/doml_mc/domlr_parser/utils.py b/mc_openapi/doml_mc/domlr_parser/utils.py
index cf1c57a..ab42f7e 100644
--- a/mc_openapi/doml_mc/domlr_parser/utils.py
+++ b/mc_openapi/doml_mc/domlr_parser/utils.py
@@ -3,8 +3,8 @@ from difflib import get_close_matches
 from mc_openapi.doml_mc.domlr_parser.exceptions import \
     RequirementMissingKeyException
 from mc_openapi.doml_mc.imc import SMTEncoding, SMTSorts
-from z3 import Const, DatatypeRef, ExprRef, FuncDeclRef, SortRef, Ints
-
+from z3 import Const, DatatypeRef, ExprRef, FuncDeclRef, SortRef, Ints, And
+from doml_synthesis import State, AssocRel, AttrRel
 
 class StringValuesCache:
     def __init__(self) -> None:
@@ -108,6 +108,55 @@ class RefHandler:
     def get_attribute_rel(enc: SMTEncoding, a: ExprRef, rel: DatatypeRef, b: ExprRef) -> DatatypeRef:
         return enc.attribute_rel(a, rel, b)
 
+class SynthesisRefHandler:
+    """A utility class that provides simplified ways to create Z3 Refs.
+    To be used when parsing requirements for synthesis
+    """
+
+    INTEGER = 2
+    BOOLEAN = 3
+    STRING = 4
+
+    def get_consts(names: list[str], state: State):
+        return [Const(name, state.sorts.Elem) for name in names]
+
+    def get_const(name: str, state: State):
+        return Const(name, state.sorts.Elem)
+
+    def get_bool(value: str):
+        return value == "!True"
+
+    def get_str(value: str, state: State):
+        return state.data.Strings[value]
+
+    def get_element_class(state: State, const: ExprRef) -> FuncDeclRef:
+        return state.rels.ElemClass(const)
+
+    def get_class(state: State, class_name: str) -> DatatypeRef:
+        class_name = _convert_rel_str(class_name)
+        _class = state.data.Classes.get(class_name, None)
+        if _class is not None:
+            return _class.ref
+        else:
+            close_matches = get_close_matches(class_name, state.data.Classes.keys())
+            raise RequirementMissingKeyException("class", class_name, close_matches)
+
+    def get_assoc(state: State, rel_name: str) -> AssocRel:
+        rel_name = _convert_rel_str(rel_name)
+        rel = state.data.Assocs.get(rel_name, None)
+        if rel is not None:
+            return rel
+        else: 
+            raise f"Association {rel_name} not present in the metamodel!"
+
+    def get_attr(state: State, rel_name: str) -> AttrRel:
+            rel_name = _convert_rel_str(rel_name)
+            rel = state.data.Attrs.get(rel_name, None)
+            if rel is not None:
+                return rel
+            else: 
+                raise f"Attribute {rel_name} not present in the metamodel!"
+
 def _convert_rel_str(rel: str) -> str:
     tokens = rel.replace("abstract", "infrastructure").split(".")
     ret = tokens[0]
diff --git a/tests/domlr/example_multiple_reqs.domlr b/tests/domlr/example_multiple_reqs.domlr
index 31f6941..4518574 100644
--- a/tests/domlr/example_multiple_reqs.domlr
+++ b/tests/domlr/example_multiple_reqs.domlr
@@ -27,6 +27,15 @@
 
     error: "VM {vm} must have cpu_count >= 4 {iface}"
 
++   "VM must have os == 'rhel8'"
+    forall vm (
+        vm is class infrastructure.VirtualMachine
+        implies
+        vm has infrastructure.ComputingNode.os == "rhel8"
+    )
+
+    error: "VM {vm} must have os == 'rhel'"
+
 # OLD SYNTAX:
 # -   "Iface must be unique"
 #     ni1 has infrastructure.NetworkInterface.endPoint Value
-- 
GitLab