"""
Implements the network representation class.
"""
from __future__ import annotations
import collections
import math
import random
import logging
from copy import deepcopy
from itertools import tee, filterfalse, islice
from queue import PriorityQueue
from typing import Union, Callable, Iterable, List, Tuple, Dict
from xml.etree import ElementTree as ET
import geojson
from pyproj import Transformer


logger = logging.getLogger(__name__)


class Network:
    def __init__(self, city, path):
        """
        Constructs a network object by reading from file.

        :param path: path to the network.xml file
        :param type: string
        :param crs_epsg: the coordinate reference system used.
            Network files used by MATSim use CRS epsg 2062.
        :param type: string
        """

        self.city = city

        # used for importing counting locations, which use a different CRS
        if city == "bilbao":
            self.crs_epsg = "32630"
        elif city == "amsterdam":
            self.crs_epsg = "32631"
        elif city == "messina":
            self.crs_epsg = "32633"
        elif city == "helsinki":
            self.crs_epsg = "32634"

        self._element_tree: Union[ET.ElementTree, None] = None
        self.count_locations = []  # used to load traffic counts
        self.bounds = {
            "x": {"low": math.inf, "high": 0.0},
            "y": {"low": math.inf, "high": 0.0},
        }  # used to create quad tree

        self.read_network(path)
        self.update_network()

    def read_network(self, path) -> None:
        """
        Reads network file from path and builds search trees.

        :param path: path to the network file
        :param type: string
        """
        self._element_tree = ET.parse(path)
        root = self._element_tree.getroot()

        nodes = []
        links = []

        for child in root:
            if child.tag == "nodes":
                for node in child:
                    nodes.append(node.attrib)
                    self.bounds["x"]["low"] = min(
                        float(node.attrib["x"]), self.bounds["x"]["low"]
                    )
                    self.bounds["x"]["high"] = max(
                        float(node.attrib["x"]), self.bounds["x"]["high"]
                    )
                    self.bounds["y"]["low"] = min(
                        float(node.attrib["y"]), self.bounds["y"]["low"]
                    )
                    self.bounds["y"]["high"] = max(
                        float(node.attrib["y"]), self.bounds["y"]["high"]
                    )

            elif child.tag == "links":
                for link in child:
                    links.append(link.attrib)
        self.nodes = nodes
        self.links = links

    def _make_link_id_lut(self) -> dict:
        """
        Creates a look-up-table of links for faster access.
        """
        lut = {}
        for i, link in enumerate(self.links):
            key = link["id"][: self.link_id_lut_key_len]
            if key in lut.keys():
                lut[key].append(i)
            else:
                lut[key] = [i]
        return lut

    # noinspection DuplicatedCode
    @staticmethod
    def _make_kdtree_nodes(
        nodes: List[dict],
        low_x: float,
        high_x: float,
        low_y: float,
        high_y: float,
        axis: str,
        depth=0,
        max_depth=None,
        max_leaf_items=1,
    ) -> Union[list, dict]:
        """
        Build the kd-tree of nodes for fast spatial search.
        """
        # print(depth)
        if axis not in "xy":
            raise ValueError(f"Illegal axis {axis}, should be either 'x' or 'y'.")

        if len(nodes) < 1:
            raise RuntimeWarning(
                "_make_kdtree_nodes did illegal recurse into empty subtree."
            )

        if len(nodes) <= max_leaf_items:
            return list(map(lambda node: node["id"], nodes))

        if max_depth is not None and depth >= max_depth:
            print(f"WARNING: max_depth {max_depth} reached.")
            return list(map(lambda node: node["id"], nodes))

        k = 10
        sample: List[Dict]
        if len(nodes) > k:
            sample = random.sample(nodes, k)
        else:
            sample = nodes
        mid = sum([float(node[axis]) for node in sample]) / len(
            sample
        )  # pylint: disable=locally-disabled, unsubscriptable-object

        high_nodes, low_nodes = Network._partition(
            lambda node: float(node[axis]) <= mid, nodes
        )

        if len(high_nodes) == 0:
            return low_nodes
        if len(low_nodes) == 0:
            return high_nodes

        if axis == "x":
            return {
                "axis": axis,
                "low": low_x,
                "mid": mid,
                "high": high_x,
                "low_sub": Network._make_kdtree_nodes(
                    low_nodes, low_x, mid, low_y, high_y, "y", depth=depth + 1
                ),
                "high_sub": Network._make_kdtree_nodes(
                    high_nodes, mid, high_x, low_y, high_y, "y", depth=depth + 1
                ),
            }
        else:
            return {
                "axis": axis,
                "low": low_y,
                "mid": mid,
                "high": high_y,
                "low_sub": Network._make_kdtree_nodes(
                    low_nodes, low_x, high_x, low_y, mid, "x", depth=depth + 1
                ),
                "high_sub": Network._make_kdtree_nodes(
                    high_nodes, low_x, high_x, mid, high_y, "x", depth=depth + 1
                ),
            }

    # noinspection DuplicatedCode
    def _make_kdtree_links(
        self,
        links: List[Dict],
        low_x: float,
        high_x: float,
        low_y: float,
        high_y: float,
        axis: str,
        depth=0,
        max_depth=None,
        max_leaf_items=1,
    ) -> Union[list, dict]:
        """
        Build the kd-tree of links for faster spatial search.
        """

        if axis not in "xy":
            raise ValueError(f"Illegal axis {axis}, should be either 'x' or 'y'.")

        if len(links) < 1:
            raise RuntimeWarning(
                "_make_kdtree_links did illegal recurse into empty subtree."
            )

        if len(links) <= max_leaf_items:
            ret = list(map(lambda link: link["id"], links))
            return ret

        if max_depth is not None and depth >= max_depth:
            print(f"WARNING: max_depth {max_depth} reached.")
            return list(map(lambda link: link["id"], links))

        k = 10
        sample: List[Dict]
        if len(links) > k:
            sample = random.sample(links, k)
        else:
            sample = links

        mid = sum(
            [
                float(
                    self.get_node(link["from"])[
                        axis
                    ]  # pylint: disable=locally-disabled, unsubscriptable-object
                )
                + float(
                    self.get_node(link["to"])[
                        axis
                    ]  # pylint: disable=locally-disabled, unsubscriptable-object
                )
                for link in sample
            ]
        ) / (2 * len(sample))

        high_links, low_links = Network._partition(
            lambda link: (
                (
                    float(self.get_node(link["from"])[axis])
                    + float(self.get_node(link["to"])[axis])
                )
                / 2
            )
            <= mid,
            links,
        )

        # high_links_axis = [((float(self.get_node(link['from'])[axis]) + float(self.get_node(link['to'])[axis])) / 2) for link in high_links]
        # low_links_axis = [((float(self.get_node(link['from'])[axis]) + float(self.get_node(link['to'])[axis])) / 2) for link in low_links]

        # this happens if the links are the "same" (From nodeA to nodeB), but different directions
        if len(high_links) == 0:
            return low_links
        if len(low_links) == 0:
            return high_links

        # recursion
        if axis == "x":
            return {
                "axis": axis,
                "low": low_x,
                "mid": mid,
                "high": high_x,
                "low_sub": self._make_kdtree_links(
                    low_links, low_x, mid, low_y, high_y, "y", depth + 1
                ),
                "high_sub": self._make_kdtree_links(
                    high_links, mid, high_x, low_y, high_y, "y", depth + 1
                ),
            }
        else:
            return {
                "axis": axis,
                "low": low_y,
                "mid": mid,
                "high": high_y,
                "low_sub": self._make_kdtree_links(
                    low_links, low_x, high_x, low_y, mid, "x", depth + 1
                ),
                "high_sub": self._make_kdtree_links(
                    high_links, low_x, high_x, mid, high_y, "x", depth + 1
                ),
            }

    @staticmethod
    def _flatten_kd(kdtree) -> list:
        """
        Flattens kd (sub) tree to a list.

        :param kdtree: kd tree or subtree
        :param type: list
        :returns: list
        """
        if isinstance(kdtree, list):
            return kdtree
        return Network._flatten_kd(kdtree["low_sub"]) + Network._flatten_kd(
            kdtree["high_sub"]
        )

    @staticmethod
    def _xml_remove_mode(mode_to_remove: str, modes: str) -> str:
        """
        Removes travel mode from string containing allowed travel modes for link.

        :param mode_to_remove: "bike", "car", or "pt"
        :param type: string
        :modes: string of allowed travel modes from the link attribute "modes".
        :param type: string
        :returns: string with the mode_to_remove remvoed.
        """
        if mode_to_remove == modes:
            raise ValueError("Cannot remove only mode. Consider deleting link instead.")
        prefix, _, postfix = str.partition(modes, mode_to_remove)
        if prefix.endswith(",") and postfix.startswith(","):
            prefix = prefix[:-1]
        return prefix + postfix

    @staticmethod
    def _partition(pred: Callable, itera: Iterable) -> Tuple[list, list]:
        """Use a predicate to partition entries into false entries and true entries."""
        t1, t2 = tee(itera)
        return list(filterfalse(pred, t1)), list(filter(pred, t2))

    @staticmethod
    def _sliding_window(a_list: Iterable, window_size=2):
        """
        Implements a sliding windows over a list.
        """
        iterator = iter(a_list)
        window = collections.deque(islice(iterator, window_size), maxlen=window_size)
        if len(window) == window_size:
            yield list(window)
        for x in iterator:
            window.append(x)
            yield list(window)

    @staticmethod
    def _index(item_id, array) -> int:
        """
        Implements binary search.
        :param item_id: item of interest
        :param array: array in which the item is
        :return: index of the item in the array: int
        """
        low = 0
        high = len(array) - 1
        mid = 0
        while low <= high:
            mid = (high + low) // 2
            # node is lexically greater, ignore left half

            if array[mid] < item_id:
                low = mid + 1

            # node is lexically smaller, ignore right half
            elif array[mid] > item_id:
                high = mid - 1
            else:
                return mid
        return -1

    def make_line_string(self, link_id, crs_transformer) -> geojson.LineString:
        """
        Returns a geojson LineString object of the link.

        :param link_id: link id string of the link to encode.
        :param crs_transformer: crs transformer object applied to the coordinates.
        :returns: geojson.LineString object of the link
        """
        return geojson.LineString(
            [
                (
                    crs_transformer.transform(
                        float(self.get_node(self.links[link_id]["from"])["x"]),
                        float(self.get_node(self.links[link_id]["from"])["y"]),
                    )
                ),
                (
                    crs_transformer.transform(
                        float(self.get_node(self.links[link_id]["to"])["x"]),
                        float(self.get_node(self.links[link_id]["to"])["y"]),
                    )
                ),
            ]
        )

    def get_node(self, node_id: str) -> dict:
        """
        Fast node retrieval by id.
        """
        index = self._index(node_id, self.node_ids)
        return self.nodes[index]

    def get_link(self, link_id: str) -> dict:
        """
        Fast link retrieval by id.
        """
        for link_index in self.link_id_lut[link_id[: self.link_id_lut_key_len]]:
            if self.links[link_index]["id"] == link_id:
                return self.links[link_index]

    @staticmethod
    def average_point(points: List[dict]) -> dict:
        """
        Returns the average coordinates of the points provided.
        """
        return {
            "x": sum([p["x"] for p in points]) / len(points),
            "y": sum([p["y"] for p in points]) / len(points),
        }

    @staticmethod
    def distance_point_point(point1, point2) -> float:
        """
        Find the euclidean distance from point1 to point2.
        :param point1: dict{'x': float, 'y': float}
        :param point2: dict{'x': float, 'y': float}
        :return: distance: float
        """
        try:
            if "pt" not in point1["id"]:
                # logger.warn(point1)
                pass
        except:
            pass
        try:
            if "pt" not in point2["id"]:
                # logger.warn(point2)
                pass
        except:
            pass

        dist = math.sqrt(
            (float(point1["x"]) - float(point2["x"])) ** 2
            + (float(point1["y"]) - float(point2["y"])) ** 2
        )
        return dist

    def distance_point_node(self, point: dict, node_id: str) -> float:
        """
        Find the distance from a node to a point.

        :param node_id: id of the node: str
        :param point: point: dict{'x': float, 'y': float}
        :return: distance: float
        """
        node = self.get_node(node_id)
        return self.distance_point_point(node, point)

    def distance_point_link(self, point: dict, link: str) -> float:
        """
        Shortest distance between a point and a link.

        :param node_id: id of the node: str
        :param point: point: dict{'x': float, 'y': float}
        :return: distance: float
        """
        from_node = self.get_node(link["from"])
        to_node = self.get_node(link["to"])

        from_point = {"x": float(from_node["x"]), "y": float(from_node["y"])}
        to_point = {"x": float(to_node["x"]), "y": float(to_node["y"])}
        link_length_sqr = (to_point["x"] - from_point["x"]) ** 2 + (
            to_point["y"] - from_point["y"]
        ) ** 2

        # links with zero length often have non-zero length due to shape simplification done by matsim
        if link_length_sqr == 0:
            return self.distance_point_point(point, from_point)

        # project onto line segment
        dot = (float(point["x"]) - float(from_point["x"])) * (
            float(to_point["x"]) - float(from_point["x"])
        ) + (float(point["y"]) - float(from_point["y"])) * (
            float(to_point["y"]) - float(from_point["y"])
        )

        t = max(0, min(1, dot / link_length_sqr))
        projected_point = {
            "x": from_point["x"] + t * (to_point["x"] - from_point["x"]),
            "y": from_point["y"] + t * (to_point["y"] - from_point["y"]),
        }

        distance = self.distance_point_point(point, projected_point)
        return distance

    def get_link_length(self, from_id, to_id) -> float:
        """
        Find the distance of the link from one node to another.
        This may not be equal to distance between both nodes
        due to simplification of the network.
        The correct length of the link is an attribute of the link.

        :param from_id: id of the first node: str
        :param to_id:  if od the second node: str
        :return: distance of the link: float
        """
        index = self._index(from_id + "-" + to_id, self.link_from_to_ids)
        link = self.links[index]
        return float(link["length"])

    def get_link_length_link(self, link_id) -> float:
        """
        Find the distance of the link from one node to another.
        This may not be equal to distance between both nodes
        due to simplification of the network.
        The correct length of the link is an attribute of the link.

        :param from_id: id of the first node: str
        :param to_id:  if od the second node: str
        :return: distance of the link: float
        """
        link = self.get_link(link_id)
        return float(link["length"])

    def search_kdtree_range(
        self, point: dict, kdtree=None, delta: float = 10
    ) -> List[str]:
        """
        Returns ids of nodes within delta distance from the point.
        """

        if kdtree is None:
            kdtree = self.node_kd_tree
        results = []

        # leaf node, not excluded, all in range
        if isinstance(kdtree, list):
            for idx, link in enumerate(kdtree):
                if not isinstance(link, str):
                    kdtree[idx] = link["id"]
            return kdtree

        # drop this subtree
        if not (
            kdtree["low"] - delta <= point[kdtree["axis"]] < kdtree["mid"] + delta
            or kdtree["mid"] - delta < point[kdtree["axis"]] <= kdtree["high"] + delta
        ):
            return []

        # recurse into left subtree
        if kdtree["low"] - delta <= point[kdtree["axis"]] < kdtree["mid"] + delta:
            results += self.search_kdtree_range(point, kdtree["low_sub"], delta)

        # recurse into right subtree
        if kdtree["mid"] - delta < point[kdtree["axis"]] <= kdtree["high"] + delta:
            results += self.search_kdtree_range(point, kdtree["high_sub"], delta)

        return results

    def get_neighbour_node_ids(self, node_id: str) -> List[str]:
        """
        Find neighbouring nodes of a node.
        :param node_id: id of the reference node.
        :return: list of node ids: list[str]
        """
        mid_index = self._index(node_id, self.link_from_ids)
        neighbour_ids: list[str] = []
        # search to the left
        for i in range(mid_index, -1, -1):
            link = self.links[i]
            if link["from"] == node_id:
                neighbour_ids.append(link["to"])
            else:
                break
        # search to the right
        for i in range(mid_index, len(self.links), 1):
            link = self.links[i]
            if link["from"] == node_id:
                neighbour_ids.append(link["to"])
            else:
                break
        neighbour_ids = sorted(neighbour_ids, key=lambda link: link[0])
        return neighbour_ids

    def nearby_links_range(self, point: Dict[str, float], delta: float) -> List[str]:
        """
        Returns links within delta distance of point.
        """
        if isinstance(point, str):
            node = self.get_node(point)
            point = {
                "x": float(node["x"]),
                "y": float(node["y"])
            }
        candidates = self.search_kdtree_range(
            point, kdtree=self.link_kd_tree, delta=delta
        )
        return candidates

    def get_links_to_node(self, node_id):
        """
        Returns links going into node (input).
        """
        to_links_indexes = [i for i, x in enumerate(self.link_from_ids) if x == node_id]
        to_links = []
        for idx in to_links_indexes:
            to_links.append(self.links[idx])
        return to_links

    def get_links_from_node(self, node_id):
        """
        Returns links going out of node (output).
        """
        from_links_indexes = [i for i, x in enumerate(self.link_to_ids) if x == node_id]
        from_links = []
        for idx in from_links_indexes:
            from_links.append(self.links[idx])
        return from_links

    def nearby_nodes_range(self, point: dict, delta: float = 10) -> List[str]:
        """
        Returns nodes within delta distance of point.
        """
        candidates = self.search_kdtree_range(point, delta=delta)
        return candidates

    def random_node_near(self, max_dist=100, point=None, node_id=None) -> dict:
        """
        Randomly select a node near a passed point or node.
        Exactly one of params point and node_id should be passed.
        :param max_dist: maximum distance to consider near in meters, default 100
        :param point: the point to measure distance from
        :param node_id: the node to measure distance from
        :return: a node within max_dist of passed point or node.
        """
        if not point and not node_id:
            raise ValueError("Either a point or a node_id must be specified.")
        if point and node_id:
            print(
                "WARNING: Only one of point or node_id should be specified. Defaulting to point."
            )
        if node_id and not point:
            node = self.get_node(node_id)
            point = {"x": node["x"], "y": node["y"]}

        nearby_nodes = list(
            filter(
                lambda node: self.distance_point_node(point, node["id"]) < max_dist,
                self.nodes,
            )
        )
        if len(nearby_nodes) < 1:
            raise RuntimeError(f"No nodes found within {max_dist}m of point {point}")
        return nearby_nodes[random.randrange(0, len(nearby_nodes))]

    def find_path(
        self, start_id: str, end_id: str, heuristic: Callable = None
    ) -> List[str]:
        """
        Finds the cheapest path from start node to end node according to the heuristic.
        If heuristic is not passed, euclidean distance is used.
        :param start_id: id of the start node
        :param end_id: id of the end node
        :param heuristic: function node_id:str->cost:float
        :return: list of node ids
        """
        open_ = PriorityQueue()
        open_.put((0, start_id))

        g_ = {start_id: 0}

        came_from_ = {start_id: start_id}

        if heuristic is None:

            def h_(node_id):
                # euclidean distance to end node
                end_ = self.get_node(end_id)
                node = self.get_node(node_id)
                return (float(end_["x"]) - float(node["x"])) ** 2 + (
                    float(end_["y"]) - float(node["y"])
                ) ** 2

        else:
            h_ = heuristic

        f_ = {start_id: 0 + h_(start_id)}

        def reconstruct_path(node):
            path = [node]
            while came_from_[node] != node:
                path = [came_from_[node]] + path
                node = came_from_[node]
            return path

        def is_in_open(node_id, open=open_):
            q = open.queue
            for _, v in q:
                if v == node_id:
                    return True
            return False

        while not open_.empty():
            _, current_node_id = open_.get()
            if current_node_id == end_id:
                return reconstruct_path(end_id)

            for neighbour_id in self.get_neighbour_node_ids(current_node_id):
                tentative_g_score = g_[current_node_id] + self.get_link_length(
                    current_node_id, neighbour_id
                )
                if tentative_g_score < g_.get(neighbour_id, math.inf):
                    came_from_[neighbour_id] = current_node_id
                    g_[neighbour_id] = tentative_g_score
                    f_[neighbour_id] = tentative_g_score + h_(neighbour_id)

                    if not is_in_open(neighbour_id):
                        open_.put((f_[neighbour_id], neighbour_id))

    def points_to_nodes(
        self, points, with_dist=False
    ) -> Union[List[str], Tuple[List[str], List[float]]]:
        """
        For each point, find the best matching link w.r.t. location.
        :param points: list[dict{'x': float, 'y': float}]
        :param with_dist: also return the distance to each point
        :return: list[link_id: str], (list[distance: float] - if with_dist==True)
        """
        results = []
        for point in points:
            node_distances = []
            for node in self.nodes:
                node_distances.append(
                    (self.distance_point_node(point, node["id"]), node["id"])
                )
            node_distances = sorted(node_distances, key=lambda node: node[0])
            results.append(node_distances[0])
        if with_dist:
            return (
                list(map(lambda node: node[1], results)),
                list(map(lambda node: node[0], results)),
            )
        return list(map(lambda node: node[1], results))

    def points_to_links(
        self, points, with_dist=False
    ) -> Union[List[str], Tuple[List[str], List[float]]]:
        """
        For each point, find the best matching link w.r.t. location.
        :param points: list[dict{'x': float, 'y': float}]
        :param with_dist: also return the distance to each point
        :return: list[link_id: str], (list[distance: float] - if with_dist==True)
        """
        results = []
        for point in points:
            link_distances = []
            # TEST

            print("DEBUG: point is ", point)
            for link in self.nearby_links_range(point, delta=30):
                link_distances.append(
                    (self.distance_point_link(point, link), link["id"])
                )
            link_distances = sorted(link_distances, key=lambda link: link[0])
            results.append(link_distances[0])
        if with_dist:
            return (
                list(map(lambda link: link[1], results)),
                list(map(lambda link: link[0], results)),
            )
        return list(map(lambda link: link[1], results))

    def get_max_node_link_id(self):
        max_node_id = 0
        max_link_id = 0
        for node in self.nodes:
            if node["id"].find("pt") == -1:
                max_node_id = max(int(node["id"]), max_node_id)
        max_node_id += 1

        for link in self.links:
            if link["id"].find("pt") == -1:
                max_link_id = max(int(link["id"]), max_link_id)
        max_link_id += 1

        return max_node_id, max_link_id

    def links_minimal_values(self, links):
        """
        The function returns minimal capaciry, freesped, permalanes and oneway
        attributes of all the links we passed to the funtion
        """
        # Find entry/exit bike_links attribute values
        bike_lane_capacity = 10000000.0
        bike_lane_freespeed = 10000000.0

        # We find min values of all links and set them
        bike_lane_found = True if len(links) > 0 else False
        for link in filter(lambda link: "bicycle" in link["modes"], links):
            bike_lane_capacity = min(bike_lane_capacity, float(link["capacity"]))
            bike_lane_freespeed = min(bike_lane_freespeed, float(link["freespeed"]))
        # If there is no bike lane to look at, set default values for bikes --> approximations from network.xml file
        if not bike_lane_found:
            bike_lane_capacity = 300.0
            bike_lane_freespeed = 8.333333333333334
        return (
            bike_lane_capacity,
            bike_lane_freespeed,
        )

    def generate_links(self, nodes, max_link_id, capacity, freespeed):
        """
        Generates new links
        """
        links = []
        for i in range(len(nodes) - 1):
            link = {}
            link["id"] = str(max_link_id)
            max_link_id += 1
            link["from"] = nodes[i]["id"]
            link["to"] = nodes[i + 1]["id"]
            link["length"] = self.distance_point_point(nodes[i], nodes[i + 1])
            # link["oneway"] = 1
            link["modes"] = "bicycle"
            link["capacity"] = capacity
            link["freespeed"] = freespeed
            link["permlanes"] = 1.0
            link["oneway"] = 1
            links.append(link)

        return links, max_link_id

    def add_bike_lane(self, links, two_way=False, lane_separation=1.5):
        # find max node and link ids
        max_node_id, max_link_id = self.get_max_node_link_id()

        new_links = []
        new_nodes = []

        crs_transformer = Transformer.from_crs("epsg:4326", f"{self.crs_epsg}")
        # generate nodes
        nodes = []
        coordinates = links["features"][0]["geometry"]["coordinates"]
        for point in coordinates:
            # logger.warning(point)
            node = {}
            # x and y have to be swapped for some reason
            node["x"], node["y"] = crs_transformer.transform(
                point[1], point[0]
            )  # pylint: disable=locally-disabled, unpacking-non-sequence
            node["id"] = str(max_node_id)
            max_node_id += 1
            nodes.append(node)
            new_nodes.append(node)
            self.nodes.append(node)
            self.bounds["x"]["low"] = min(self.bounds["x"]["low"], node["x"])
            self.bounds["x"]["high"] = max(self.bounds["x"]["high"], node["x"])
            self.bounds["y"]["low"] = min(self.bounds["y"]["low"], node["y"])
            self.bounds["y"]["high"] = max(self.bounds["y"]["high"], node["y"])

        # Select entry node
        nearby_node_range_ = self.nearby_nodes_range(nodes[0], 10)
        nearby_nodes = [self.get_node(n) for n in nearby_node_range_]
        entry_node = nearby_nodes[0]
        cur_min_distance = self.distance_point_point(nodes[0], entry_node)
        for node in nearby_nodes:
            if (
                self.distance_point_point(nodes[0], node) < cur_min_distance
                and node["id"] in self.link_from_ids
                and node["id"] in self.link_to_ids
            ):
                cur_min_distance = self.distance_point_point(nodes[0], node)
                entry_node = node
        nodes.insert(0, entry_node)

        # Select exit node
        nearby_nodes = [
            self.get_node(n) for n in self.nearby_nodes_range(nodes[-1], 10)
        ]
        exit_node = nearby_nodes[0]
        cur_min_distance = self.distance_point_point(nodes[-1], exit_node)
        for node in nearby_nodes:
            if (
                self.distance_point_point(nodes[-1], node) < cur_min_distance
                and node["id"] in self.link_from_ids
                and node["id"] in self.link_to_ids
            ):
                cur_min_distance = self.distance_point_point(nodes[-1], node)
                exit_node = node
        nodes.append(exit_node)
        # Check entry and exit links for bike lanes and adjust default values if you find something
        entry_node_id = entry_node["id"]
        entry_output_links = self.get_links_to_node(entry_node_id)
        exit_node_id = exit_node["id"]
        exit_input_links = self.get_links_from_node(exit_node_id)
        # Get minmal values for the links
        all_links_one_way = entry_output_links + exit_input_links
        capacity, freespeed = self.links_minimal_values(all_links_one_way)
        # generate the links one way
        generated_links, max_link_id = self.generate_links(
            nodes, max_link_id, capacity, freespeed
        )
        self.links += generated_links
        new_links += generated_links
        # Creating two direction bikelanes
        if two_way:
            # reverse array of nodes and create a temp copy for offset calculation
            nodes.reverse()
            points = deepcopy(nodes)
            for _, i in enumerate(points):
                # calculate offset for all nodes except entry and exit
                if i - 1 >= 0 and i + 1 < len(points):
                    dx = (
                        float(points[i - 1]["x"])
                        + float(points[i + 1]["x"])
                        - 2 * float(points[i]["x"])
                    )
                    dy = (
                        float(points[i - 1]["y"])
                        + float(points[i + 1]["y"])
                        - 2 * float(points[i]["y"])
                    )
                    l = math.sqrt(dx**2 + dy**2)
                    dx = dx / l * lane_separation
                    dy = dy / l * lane_separation
                    nodes[i]["x"] += dx
                    nodes[i]["y"] += dy
                    nodes[i]["id"] = str(max_node_id)
                    max_node_id += 1
                    new_nodes.append(nodes[i])
                    self.nodes.append(nodes[i])
                    self.bounds["x"]["low"] = min(
                        self.bounds["x"]["low"], nodes[i]["x"]
                    )
                    self.bounds["x"]["high"] = max(
                        self.bounds["x"]["high"], nodes[i]["x"]
                    )
                    self.bounds["y"]["low"] = min(
                        self.bounds["y"]["low"], nodes[i]["y"]
                    )
                    self.bounds["y"]["high"] = max(
                        self.bounds["y"]["high"], nodes[i]["y"]
                    )
                # generate links for two way
                generated_links, max_link_id = self.generate_links(
                    nodes, max_link_id, capacity, freespeed
                )
                self.links += generated_links
                new_links += generated_links
        self.update_network()
        for child in self._element_tree.getroot():
            if child.tag == "nodes":
                for node in new_nodes:
                    child.append(ET.Element("node", {k: str(node[k]) for k in node}))
            if child.tag == "links":
                for link in new_links:
                    el_link = ET.Element("link", {k: str(link[k]) for k in link})
                    el_attributes = ET.Element("attributes")
                    el_attribute = ET.Element("attribute")
                    el_attribute.set("name", "osm:way:highway")
                    el_attribute.set("class", "java.lang.String")
                    el_attribute.text = "cycleway"
                    el_attributes.append(el_attribute)
                    el_link.append(el_attributes)
                    child.append(el_link)
        # TODO replace 1 with network id, create network in DB
        with open(f"data/{self.city}/original_input/networks/1/network.xml", "wb") as f:
            f.write(
                b'<?xml version="1.0" encoding="UTF-8"?>\n<!DOCTYPE network SYSTEM "http://www.matsim.org/files/dtd/network_v2.dtd">\n'
            )
            self._element_tree.write(f)
        return new_links

    def update_network(self):
        # sys.setrecursionlimit(10000)
        self.nodes = sorted(self.nodes, key=lambda node: node["id"])
        self.links = sorted(
            self.links, key=lambda link: link["from"] + "-" + link["to"]
        )

        # following data structures are used to speed up algorithms.
        self.link_id_lut_key_len = 5
        self.link_id_lut = self._make_link_id_lut()
        self.node_ids = [node["id"] for node in self.nodes]
        self.link_from_to_ids = [link["from"] + "-" + link["to"] for link in self.links]
        self.link_from_ids = [link["from"] for link in self.links]
        self.link_to_ids = [link["to"] for link in self.links]

        self.max_leaf_items = 1
        self.node_kd_tree = self._make_kdtree_nodes(
            self.nodes,
            self.bounds["x"]["low"],
            self.bounds["x"]["high"],
            self.bounds["y"]["low"],
            self.bounds["y"]["high"],
            axis="x",
            max_leaf_items=self.max_leaf_items,
        )
        self.link_kd_tree = self._make_kdtree_links(
            self.links,
            self.bounds["x"]["low"],
            self.bounds["x"]["high"],
            self.bounds["y"]["low"],
            self.bounds["y"]["high"],
            axis="x",
            max_leaf_items=self.max_leaf_items,
        )
        return