from ..waveguides.cells import StripWaveguide, ArcPathPcell
import ipkiss3.all as i3
import numpy as np


class RingRes_1x1Pcell(i3.PCell):
    """
    1x1 Ring Resonator.

    The ring is built from four 90-degree arcs (``ArcPathPcell``). It is
    optionally extended with two horizontal straights of ``straight_length``
    and two vertical straights of ``straight_height``. The ring is coupled to
    a single straight bus waveguide running below it.

    The bus waveguide's ``in0``/``out0`` ports are exposed as ``in0``/``out0``.
    """

    _doc_properties = []

    _name_prefix = "RingRes_1x1"
    # Sub-cells used to draw the ring and the bus; created internally (locked=True).
    arc = i3.ChildCellProperty(locked=True)
    straight_wg = i3.ChildCellProperty(locked=True)
    straight_h = i3.ChildCellProperty(locked=True)
    straight_v = i3.ChildCellProperty(locked=True)

    def _default_arc(self):
        return ArcPathPcell(name=self.name + "arc")

    def _default_straight_wg(self):
        return StripWaveguide(name=self.name + "straight_wg")

    def _default_straight_h(self):
        return StripWaveguide(name=self.name + "straight_h")

    def _default_straight_v(self):
        return StripWaveguide(name=self.name + "straight_v")

    class Layout(i3.LayoutView):
        _doc_properties = [
            "guide_width",
            "path_width",
            "straight_length",
            "straight_height",
            "bend_radius",
            "separation",
            "use_total_length",
            "total_length",
        ]

        guide_width = i3.PositiveNumberProperty(default=0.6, doc="Width of the straight waveguide")
        path_width = i3.PositiveNumberProperty(default=0.6, doc="Width of the ring")
        straight_length = i3.NonNegativeNumberProperty(default=0, doc="Length of the straight in the ring")
        straight_height = i3.NonNegativeNumberProperty(default=0, doc="Height of the straight in the ring")
        bend_radius = i3.PositiveNumberProperty(default=i3.TECH.TRACE.BEND_RADIUS, doc="Bend radius of the ring")
        separation = i3.PositiveNumberProperty(default=0.5, doc="Gap between straight and ring")
        use_total_length = i3.BoolProperty(default=False, doc="Use total_length parameter?")
        total_length = i3.PositiveNumberProperty(default=1100.0, doc="Total Path Length of the ring")

        def _ring_straight_height(self):
            """Return the length of each of the two vertical straights of the ring.

            When ``use_total_length`` is set, the vertical straights take up
            whatever path length the four arcs (``2*pi*bend_radius``) and the
            two horizontal straights (``2*straight_length``) leave over. The
            ring's total path length (arcs + straights) then equals
            ``total_length``. Otherwise the ``straight_height`` property is
            used as-is.
            """
            if self.use_total_length:
                return (self.total_length - 2 * np.pi * self.bend_radius - 2 * self.straight_length) / 2.0
            return self.straight_height

        def validate_properties(self):
            """Reject a ``total_length`` too short to hold the arcs and horizontal straights.

            Raises:
                i3.PropertyValidationError: if ``use_total_length`` is set and
                    ``total_length`` is below ``2*pi*bend_radius + 2*straight_length``.
                    A shorter ring would need negative vertical straights.
            """
            if self.use_total_length:
                min_length = 2 * np.pi * self.bend_radius + 2 * self.straight_length
                if self.total_length < min_length:
                    # Equality is accepted (vertical straights of length 0), hence "at least".
                    raise i3.PropertyValidationError(
                        self, "Total Ring Length should be at least {}".format(min_length)
                    )
            return True

        def _default_arc(self):
            # One quarter-circle view, reused for all four corners of the ring.
            lv = self.cell.arc.get_default_view(self)
            lv.set(
                width=self.path_width,
                arc_angle=90,
                radius=self.bend_radius,
            )
            return lv

        def _default_straight_wg(self):
            # The bus spans the ring's full outer width: 2 * bend_radius + straight_length + path_width.
            lv = self.cell.straight_wg.get_default_view(self)
            lv.set(
                width=self.guide_width,
                shape=[(0.0, 0.0), (self.straight_length + self.path_width + 2 * self.bend_radius, 0.0)],
            )
            return lv

        def _default_straight_h(self):
            # Horizontal ring straight (used for both the bottom and the top side).
            lv = self.cell.straight_h.get_default_view(self)
            lv.set(
                width=self.path_width,
                shape=[(0.0, 0.0), (self.straight_length, 0.0)],
            )
            return lv

        def _default_straight_v(self):
            # Vertical ring straight (used for both the left and the right side).
            lv = self.cell.straight_v.get_default_view(self)
            lv.set(
                width=self.path_width,
                shape=[(0.0, 0.0), (0.0, self._ring_straight_height())],
            )
            return lv

        def _generate_instances(self, insts):
            """Place the four arcs, the bus and any non-zero straights of the ring."""
            straight_length = self.straight_length
            straight_height = self._ring_straight_height()
            bend_radius = self.bend_radius
            path_width = self.path_width
            half_path_width = path_width / 2.0

            # Centre-to-centre distance between the bus and the ring: the
            # edge-to-edge separation plus half of each waveguide width.
            total_gap = self.separation + (path_width + self.guide_width) / 2.0

            arc = self.arc
            insts_dict = {
                "arc_bl": arc,
                "arc_br": arc,
                "arc_tr": arc,
                "arc_tl": arc,
                "straight_wg_b": self.straight_wg,
            }

            # NOTE(review): the coordinates below assume ArcPathPcell draws a
            # quarter circle from its local origin (in0, heading +x) to
            # (radius, radius) (out0, heading +y). Confirm against ArcPathPcell.
            # Under that convention, the ring's bottom and top centre-lines
            # sit at y = total_gap and y = 2 * bend_radius + straight_height + total_gap.
            specs_array = [
                # Mirror the bottom-left corner arc.
                i3.FlipH("arc_bl"),
                # Bus waveguide along y = 0.
                i3.Place("straight_wg_b", (0, 0)),
                # Bottom-right corner.
                i3.Place(
                    "arc_br",
                    (bend_radius + straight_length + half_path_width, total_gap),
                ),
                # Bottom-left corner (mirrored above).
                i3.Place(
                    "arc_bl",
                    (bend_radius + half_path_width, total_gap),
                ),
                # Top-right corner, rotated by 90 degrees.
                i3.Place(
                    "arc_tr",
                    (
                        2 * bend_radius + straight_length + half_path_width,
                        bend_radius + straight_height + total_gap,
                    ),
                    90,
                ),
                # Top-left corner, rotated by 180 degrees.
                i3.Place(
                    "arc_tl",
                    (
                        bend_radius + half_path_width,
                        2 * bend_radius + straight_height + total_gap,
                    ),
                    180,
                ),
            ]

            # Straights are added (and their child views requested) only when
            # their length is positive, so zero-length waveguide views are never built.
            if straight_length > 0:
                straight_h = self.straight_h
                insts_dict["straight_h_b"] = straight_h
                insts_dict["straight_h_t"] = straight_h
                # Bottom straight ends at arc_br's input; top straight ends at arc_tr's output.
                specs_array.append(i3.Join("arc_br:in0", "straight_h_b:out0"))
                specs_array.append(i3.Join("arc_tr:out0", "straight_h_t:out0"))
            if straight_height > 0:
                straight_v = self.straight_v
                insts_dict["straight_v_l"] = straight_v
                insts_dict["straight_v_r"] = straight_v
                # Vertical straights start at the outputs of the two bottom arcs.
                specs_array.append(i3.Join("arc_br:out0", "straight_v_r:in0"))
                specs_array.append(i3.Join("arc_bl:out0", "straight_v_l:in0"))

            insts += i3.place_and_route(
                insts=insts_dict,
                specs=specs_array,
            )
            return insts

        def _generate_ports(self, ports):
            """Expose the bus waveguide's ends as ``in0`` and ``out0``."""
            exposed_ports = {
                "straight_wg_b:in0": "in0",
                "straight_wg_b:out0": "out0",
            }
            ports += i3.expose_ports(self.instances, exposed_ports)
            return ports

    class Netlist(i3.NetlistFromLayout):
        pass


class RingRes_2x2Pcell(i3.PCell):
    """
    2x2 Ring Resonator.

    The ring is built from four 90-degree arcs (``ArcPathPcell``). It is
    optionally extended with two horizontal straights of ``straight_length``
    and two vertical straights of ``straight_height``. The ring is coupled to
    two identical straight bus waveguides, one below and one above it.

    Ports: the bottom bus's ``in0``/``out0`` are exposed as ``in0``/``out0``.
    The top bus's ``in0``/``out0`` are exposed as ``in1``/``out1``.
    """

    _doc_properties = []

    _name_prefix = "RingRes_2x2"
    # Sub-cells used to draw the ring and the buses; created internally (locked=True).
    arc = i3.ChildCellProperty(locked=True)
    straight_wg = i3.ChildCellProperty(locked=True)
    straight_h = i3.ChildCellProperty(locked=True)
    straight_v = i3.ChildCellProperty(locked=True)

    def _default_arc(self):
        return ArcPathPcell(name=self.name + "arc")

    def _default_straight_wg(self):
        return StripWaveguide(name=self.name + "straight_wg")

    def _default_straight_h(self):
        return StripWaveguide(name=self.name + "straight_h")

    def _default_straight_v(self):
        return StripWaveguide(name=self.name + "straight_v")

    class Layout(i3.LayoutView):
        _doc_properties = [
            "guide_width",
            "path_width",
            "straight_length",
            "straight_height",
            "bend_radius",
            "separation",
            "use_total_length",
            "total_length",
        ]

        guide_width = i3.PositiveNumberProperty(default=0.6, doc="Width of the straight waveguide")
        path_width = i3.PositiveNumberProperty(default=0.6, doc="Width of the ring")
        straight_length = i3.NonNegativeNumberProperty(default=0, doc="Length of the straight in the ring")
        straight_height = i3.NonNegativeNumberProperty(default=0, doc="Height of the straight in the ring")
        bend_radius = i3.PositiveNumberProperty(default=i3.TECH.TRACE.BEND_RADIUS, doc="Bend radius of the ring")
        separation = i3.PositiveNumberProperty(default=0.5, doc="Gap between straight and ring")
        use_total_length = i3.BoolProperty(default=False, doc="Use total_length parameter?")
        total_length = i3.PositiveNumberProperty(default=1100.0, doc="Total Path Length of the ring")

        def _ring_straight_height(self):
            """Return the length of each of the two vertical straights of the ring.

            When ``use_total_length`` is set, the vertical straights take up
            whatever path length the four arcs (``2*pi*bend_radius``) and the
            two horizontal straights (``2*straight_length``) leave over. The
            ring's total path length (arcs + straights) then equals
            ``total_length``. Otherwise the ``straight_height`` property is
            used as-is.
            """
            if self.use_total_length:
                return (self.total_length - 2 * np.pi * self.bend_radius - 2 * self.straight_length) / 2.0
            return self.straight_height

        def validate_properties(self):
            """Reject a ``total_length`` too short to hold the arcs and horizontal straights.

            Raises:
                i3.PropertyValidationError: if ``use_total_length`` is set and
                    ``total_length`` is below ``2*pi*bend_radius + 2*straight_length``.
                    A shorter ring would need negative vertical straights.
            """
            if self.use_total_length:
                min_length = 2 * np.pi * self.bend_radius + 2 * self.straight_length
                if self.total_length < min_length:
                    # Equality is accepted (vertical straights of length 0), hence "at least".
                    raise i3.PropertyValidationError(
                        self, "Total Ring Length should be at least {}".format(min_length)
                    )
            return True

        def _default_arc(self):
            # One quarter-circle view, reused for all four corners of the ring.
            lv = self.cell.arc.get_default_view(self)
            lv.set(
                width=self.path_width,
                arc_angle=90,
                radius=self.bend_radius,
            )
            return lv

        def _default_straight_wg(self):
            # Both buses span the ring's full outer width: 2 * bend_radius + straight_length + path_width.
            lv = self.cell.straight_wg.get_default_view(self)
            lv.set(
                width=self.guide_width,
                shape=[(0.0, 0.0), (self.straight_length + self.path_width + 2 * self.bend_radius, 0.0)],
            )
            return lv

        def _default_straight_h(self):
            # Horizontal ring straight (used for both the bottom and the top side).
            lv = self.cell.straight_h.get_default_view(self)
            lv.set(
                width=self.path_width,
                shape=[(0.0, 0.0), (self.straight_length, 0.0)],
            )
            return lv

        def _default_straight_v(self):
            # Vertical ring straight (used for both the left and the right side).
            lv = self.cell.straight_v.get_default_view(self)
            lv.set(
                width=self.path_width,
                shape=[(0.0, 0.0), (0.0, self._ring_straight_height())],
            )
            return lv

        def _generate_instances(self, insts):
            """Place the four arcs, both buses and any non-zero straights of the ring."""
            straight_length = self.straight_length
            straight_height = self._ring_straight_height()
            bend_radius = self.bend_radius
            path_width = self.path_width
            half_path_width = path_width / 2.0

            # Centre-to-centre distance between a bus and the ring: the
            # edge-to-edge separation plus half of each waveguide width.
            total_gap = self.separation + (path_width + self.guide_width) / 2.0

            arc = self.arc
            straight_wg = self.straight_wg
            insts_dict = {
                "arc_bl": arc,
                "arc_br": arc,
                "arc_tr": arc,
                "arc_tl": arc,
                "straight_wg_b": straight_wg,
                "straight_wg_t": straight_wg,
            }

            # NOTE(review): the coordinates below assume ArcPathPcell draws a
            # quarter circle from its local origin (in0, heading +x) to
            # (radius, radius) (out0, heading +y). Confirm against ArcPathPcell.
            # Under that convention, the ring's bottom and top centre-lines
            # sit at y = total_gap and y = 2 * bend_radius + straight_height + total_gap.
            specs_array = [
                # Mirror the bottom-left corner arc.
                i3.FlipH("arc_bl"),
                # Bottom bus along y = 0.
                i3.Place("straight_wg_b", (0, 0)),
                # Top bus: one total_gap above the ring's top centre-line.
                i3.Place("straight_wg_t", (0, 2 * bend_radius + straight_height + 2 * total_gap)),
                # Bottom-right corner.
                i3.Place(
                    "arc_br",
                    (bend_radius + straight_length + half_path_width, total_gap),
                ),
                # Bottom-left corner (mirrored above).
                i3.Place(
                    "arc_bl",
                    (bend_radius + half_path_width, total_gap),
                ),
                # Top-right corner, rotated by 90 degrees.
                i3.Place(
                    "arc_tr",
                    (
                        2 * bend_radius + straight_length + half_path_width,
                        bend_radius + straight_height + total_gap,
                    ),
                    90,
                ),
                # Top-left corner, rotated by 180 degrees.
                i3.Place(
                    "arc_tl",
                    (
                        bend_radius + half_path_width,
                        2 * bend_radius + straight_height + total_gap,
                    ),
                    180,
                ),
            ]

            # Straights are added (and their child views requested) only when
            # their length is positive, so zero-length waveguide views are never built.
            if straight_length > 0:
                straight_h = self.straight_h
                insts_dict["straight_h_b"] = straight_h
                insts_dict["straight_h_t"] = straight_h
                # Bottom straight ends at arc_br's input; top straight ends at arc_tr's output.
                specs_array.append(i3.Join("arc_br:in0", "straight_h_b:out0"))
                specs_array.append(i3.Join("arc_tr:out0", "straight_h_t:out0"))
            if straight_height > 0:
                straight_v = self.straight_v
                insts_dict["straight_v_l"] = straight_v
                insts_dict["straight_v_r"] = straight_v
                # Vertical straights start at the outputs of the two bottom arcs.
                specs_array.append(i3.Join("arc_br:out0", "straight_v_r:in0"))
                specs_array.append(i3.Join("arc_bl:out0", "straight_v_l:in0"))

            insts += i3.place_and_route(
                insts=insts_dict,
                specs=specs_array,
            )
            return insts

        def _generate_ports(self, ports):
            """Expose the bottom bus as ``in0``/``out0`` and the top bus as ``in1``/``out1``."""
            exposed_ports = {
                "straight_wg_b:in0": "in0",
                "straight_wg_b:out0": "out0",
                "straight_wg_t:in0": "in1",
                "straight_wg_t:out0": "out1",
            }
            ports += i3.expose_ports(self.instances, exposed_ports)
            return ports

    class Netlist(i3.NetlistFromLayout):
        pass
