from ..waveguides.cells import StripWaveguide, SBend, TaperPCell
from ..waveguides.waveguide_traces import WG_TMPL
import ipkiss3.all as i3


class MMI_1x2Pcell(i3.PCell):
    """
    1x2 MMI.

    An input taper feeds a straight multimode section. Two horizontally
    flipped output tapers followed by S-bends fan the two outputs out to
    y = +/- inp_separation / 2.
    """

    _doc_properties = []
    _name_prefix = "MMI_1x2"
    straight = i3.ChildCellProperty(locked=True)
    taper = i3.ChildCellProperty(locked=True)
    sbend = i3.ChildCellProperty(locked=True)

    def _default_sbend(self):
        return SBend(name=self.name + "sbend")

    def _default_straight(self):
        return StripWaveguide(name=self.name + "straight")

    def _default_taper(self):
        return TaperPCell(name=self.name + "taper")

    class Layout(i3.LayoutView):
        _doc_properties = [
            "tap_length",
            "tap_width_out",
            "inp_separation",
            "inp_width",
            "mmi_separation",
            "mmi_width",
            "mmi_length",
            "sbn_length",
        ]

        tap_length = i3.PositiveNumberProperty(doc="Length of tapers", default=15)
        tap_width_out = i3.PositiveNumberProperty(doc="Output width of tapers", default=5)
        inp_separation = i3.PositiveNumberProperty(doc="Separation b/w output ports", default=20)
        inp_width = i3.PositiveNumberProperty(doc="Input width of tapers", default=0.6)
        mmi_separation = i3.PositiveNumberProperty(doc="Separation b/w output tapers", default=5)
        mmi_width = i3.PositiveNumberProperty(doc="Width of MMI", default=10)
        mmi_length = i3.PositiveNumberProperty(doc="Length of MMI", default=25)
        sbn_length = i3.PositiveNumberProperty(doc="Length of sbends at the output", default=20)

        def _default_taper(self):
            lv = self.cell.taper.get_default_view(self)
            lv.set(
                initial_width=self.inp_width,
                final_width=self.tap_width_out,
                length=self.tap_length,
            )
            return lv

        def _default_straight(self):
            lv = self.cell.straight.get_default_view(self)
            lv.set(width=self.mmi_width, shape=[(0.0, 0.0), (self.mmi_length, 0.0)])
            return lv

        def _default_sbend(self):
            lv = self.cell.sbend.get_default_view(self)
            # The S-bend bridges the taper pitch (mmi_separation) and the port
            # pitch (inp_separation), so each bend shifts by half the difference.
            lv.set(
                width=self.inp_width,
                length=self.sbn_length,
                height=(self.inp_separation - self.mmi_separation) / 2.0,
            )
            return lv

        @staticmethod
        def _find_extension_point(angle_pt_groups, bound_x, center_y, rightwards=True):
            """Pick the apex point used to fill the gap between two tapers.

            Parameters
            ----------
            angle_pt_groups : list
                Values of the mapping returned by ``i3.get_acute_angle_points``.
                Only the first group is inspected. The first item of each entry
                in that group is used as the point.
            bound_x : float
                Starting x threshold. Candidates must lie strictly beyond it
                (to the right if ``rightwards`` is True, otherwise to the left).
                The threshold advances to each accepted candidate.
            center_y : float
                y of the MMI axis. A candidate lying exactly on it enables the fill.
            rightwards : bool
                Direction in which "beyond" is measured.

            Returns
            -------
            tuple
                ``(extend, apex)``. ``extend`` says whether a filling triangle
                should be added. ``apex`` is the chosen point, or None if no
                candidate qualified.
            """
            if not angle_pt_groups:
                return False, None
            candidates = angle_pt_groups[0]
            if len(candidates) == 1:
                # A single acute point: always fill from it.
                return True, candidates[0][0]
            extend = False
            apex = None
            seen = []
            for entry in candidates:
                pt = entry[0]
                beyond = pt[0] > bound_x if rightwards else pt[0] < bound_x
                if beyond:
                    apex = pt
                    bound_x = pt[0]
                    if pt[1] == center_y:
                        extend = True
                if pt in seen:
                    # The same point reported twice: fill from it immediately.
                    return True, pt
                seen.append(pt)
            return extend, apex

        def _generate_elements(self, elems):
            """Place all sub-cells and merge the two output tapers into one polygon.

            The output tapers are merged on ``i3.TECH.PPLAYER.WG``. If their
            outlines leave an acute notch, a small triangle is added to fill it
            before merging. All other sub-cells are added as flat copies.
            """
            mmi_length = self.mmi_length
            tap_length = self.tap_length
            mmi_gap = self.mmi_separation

            layout = i3.place_and_route(
                insts={
                    "taper1": self.taper,
                    "taper2": self.taper,
                    "taper3": self.taper,
                    "sbend1": self.sbend,
                    "sbend2": self.sbend,
                    "straight": self.straight,
                },
                specs=[
                    i3.FlipV("sbend2"),
                    i3.FlipH("taper2"),
                    i3.FlipH("taper3"),
                    i3.Place("straight", (tap_length, 0)),
                    i3.Place("taper1", (0, 0)),
                    i3.Place("taper2", (mmi_length + 2 * tap_length, mmi_gap / 2.0)),
                    i3.Place("taper3", (mmi_length + 2 * tap_length, -mmi_gap / 2.0)),
                    i3.Place("sbend1", (mmi_length + 2 * tap_length, mmi_gap / 2.0)),
                    i3.Place("sbend2", (mmi_length + 2 * tap_length, -mmi_gap / 2.0)),
                ],
            )

            # Output tapers are kept apart so they can be merged into one shape.
            output_tapers = []
            for name in layout:
                if name in ("taper2", "taper3"):
                    output_tapers.append(layout[name])
                    continue
                if name == "sbend1":
                    sbend1_port = layout[name].in_ports[0].position
                elif name == "sbend2":
                    sbend2_port = layout[name].in_ports[0].position
                elems += layout[name].flat_copy()

            angle_pts = list(i3.get_acute_angle_points(output_tapers).values())
            extend, right_pt = self._find_extension_point(
                angle_pts, mmi_length + tap_length, 0, rightwards=True
            )
            if extend:
                # Triangle from the apex to the two S-bend input y levels,
                # reaching tap_length / 10 further right.
                fill_x = right_pt[0] + tap_length / 10.0
                fill_shape = i3.Shape(
                    points=[
                        right_pt,
                        (fill_x, sbend1_port[1]),
                        (fill_x, sbend2_port[1]),
                    ]
                )
                # append, not +=: += on a list would extend() it with the
                # element's iteration rather than adding the element itself.
                output_tapers.append(i3.Boundary(layer=i3.TECH.PPLAYER.WG, shape=fill_shape))

            final_elems = i3.merge_elements(output_tapers, i3.TECH.PPLAYER.WG)
            final_elems += elems
            for elem in final_elems:
                elem.shape = elem.shape.remove_straight_angles()
            return final_elems

        def _generate_ports(self, ports):
            """Add the input port at the origin and two outputs at +/- inp_separation / 2."""
            tt = WG_TMPL(name=self.name + "_tt").Layout(core_width=self.inp_width).cell
            out_x = self.tap_length + self.mmi_length + self.tap_length + self.sbn_length
            ports += i3.OpticalPort(
                name="in0",
                position=(0.0, 0.0),
                angle=180.0,
                trace_template=tt,
            )
            ports += i3.OpticalPort(
                name="out0",
                position=(out_x, -self.inp_separation / 2.0),
                angle=0.0,
                trace_template=tt,
            )
            ports += i3.OpticalPort(
                name="out1",
                position=(out_x, self.inp_separation / 2.0),
                angle=0.0,
                trace_template=tt,
            )
            return ports

    class Netlist(i3.NetlistFromLayout):
        pass


class MMI_2x2Pcell(i3.PCell):
    """
    2x2 MMI.

    Two S-bends and tapers on each side of a straight multimode section.
    Inputs (left) and outputs (right) sit at y = 0 and y = inp_separation.
    """

    _doc_properties = []
    _name_prefix = "MMI_2x2"
    straight = i3.ChildCellProperty(locked=True)
    taper = i3.ChildCellProperty(locked=True)
    sbend = i3.ChildCellProperty(locked=True)

    def _default_sbend(self):
        return SBend(name=self.name + "sbend")

    def _default_straight(self):
        return StripWaveguide(name=self.name + "straight")

    def _default_taper(self):
        return TaperPCell(name=self.name + "taper")

    class Layout(i3.LayoutView):
        _doc_properties = [
            "tap_length",
            "tap_width_out",
            "inp_separation",
            "inp_width",
            "mmi_separation",
            "mmi_width",
            "mmi_length",
            "sbn_length",
        ]

        tap_length = i3.PositiveNumberProperty(doc="Length of tapers", default=6)
        tap_width_out = i3.PositiveNumberProperty(doc="Output width of tapers", default=5)
        inp_separation = i3.PositiveNumberProperty(doc="Separation b/w ports", default=20)
        inp_width = i3.PositiveNumberProperty(doc="Input width of tapers", default=0.6)
        mmi_separation = i3.PositiveNumberProperty(doc="Separation b/w output tapers", default=5)
        mmi_width = i3.PositiveNumberProperty(doc="Width of MMI", default=10)
        mmi_length = i3.PositiveNumberProperty(doc="Length of MMI", default=25)
        sbn_length = i3.PositiveNumberProperty(doc="Length of sbends at the output", default=20)

        def _default_taper(self):
            lv = self.cell.taper.get_default_view(self)
            lv.set(
                initial_width=self.inp_width,
                final_width=self.tap_width_out,
                length=self.tap_length,
            )
            return lv

        def _default_straight(self):
            lv = self.cell.straight.get_default_view(self)
            lv.set(width=self.mmi_width, shape=[(0.0, 0.0), (self.mmi_length, 0.0)])
            return lv

        def _default_sbend(self):
            lv = self.cell.sbend.get_default_view(self)
            # Each S-bend bridges half the difference between the port pitch
            # (inp_separation) and the taper pitch (mmi_separation).
            lv.set(
                width=self.inp_width,
                length=self.sbn_length,
                height=(self.inp_separation - self.mmi_separation) / 2.0,
            )
            return lv

        @staticmethod
        def _find_extension_point(angle_pt_groups, bound_x, center_y, rightwards=True):
            """Pick the apex point used to fill the gap between two tapers.

            Parameters
            ----------
            angle_pt_groups : list
                Values of the mapping returned by ``i3.get_acute_angle_points``.
                Only the first group is inspected. The first item of each entry
                in that group is used as the point.
            bound_x : float
                Starting x threshold. Candidates must lie strictly beyond it
                (to the right if ``rightwards`` is True, otherwise to the left).
                The threshold advances to each accepted candidate.
            center_y : float
                y of the MMI axis. A candidate lying exactly on it enables the fill.
            rightwards : bool
                Direction in which "beyond" is measured.

            Returns
            -------
            tuple
                ``(extend, apex)``. ``extend`` says whether a filling triangle
                should be added. ``apex`` is the chosen point, or None if no
                candidate qualified.
            """
            if not angle_pt_groups:
                return False, None
            candidates = angle_pt_groups[0]
            if len(candidates) == 1:
                # A single acute point: always fill from it.
                return True, candidates[0][0]
            extend = False
            apex = None
            seen = []
            for entry in candidates:
                pt = entry[0]
                beyond = pt[0] > bound_x if rightwards else pt[0] < bound_x
                if beyond:
                    apex = pt
                    bound_x = pt[0]
                    if pt[1] == center_y:
                        extend = True
                if pt in seen:
                    # The same point reported twice: fill from it immediately.
                    return True, pt
                seen.append(pt)
            return extend, apex

        def _generate_elements(self, elems):
            """Place all sub-cells and merge each side's taper pair into one polygon.

            Left and right taper pairs are merged separately on
            ``i3.TECH.PPLAYER.WG``. Where their outlines leave an acute notch,
            a small triangle is added to fill it before merging. All other
            sub-cells are added as flat copies.
            """
            mmi_separation = self.mmi_separation
            inp_separation = self.inp_separation
            mmi_length = self.mmi_length
            tap_length = self.tap_length
            sbn_length = self.sbn_length
            # Vertical shift so that the lower input/output ports land on y = 0.
            offset = inp_separation / 2.0
            # x where the right tapers end and the right S-bends begin.
            out_x = 2 * tap_length + mmi_length + sbn_length

            layout = i3.place_and_route(
                insts={
                    "taper1": self.taper,
                    "taper2": self.taper,
                    "taper3": self.taper,
                    "taper4": self.taper,
                    "sbend1": self.sbend,
                    "sbend2": self.sbend,
                    "sbend3": self.sbend,
                    "sbend4": self.sbend,
                    "straight": self.straight,
                },
                specs=[
                    i3.FlipV("sbend1"),
                    i3.FlipV("sbend4"),
                    i3.FlipH("taper4"),
                    i3.FlipH("taper3"),
                    i3.Place("straight", (sbn_length + tap_length, offset)),
                    i3.Place("taper1", (sbn_length, mmi_separation / 2.0 + offset)),
                    i3.Place("taper2", (sbn_length, -mmi_separation / 2.0 + offset)),
                    i3.Place("taper3", (out_x, mmi_separation / 2.0 + offset)),
                    i3.Place("taper4", (out_x, -mmi_separation / 2.0 + offset)),
                    i3.Place("sbend1", (0, inp_separation / 2.0 + offset)),
                    i3.Place("sbend2", (0, -inp_separation / 2.0 + offset)),
                    i3.Place("sbend3", (out_x, mmi_separation / 2.0 + offset)),
                    i3.Place("sbend4", (out_x, -mmi_separation / 2.0 + offset)),
                ],
            )

            mmi_box = layout["straight"].flat_copy()[0].size_info()
            mmi_center_y = (mmi_box.south + mmi_box.north) / 2.0

            # Tapers are kept apart per side so each pair can be merged.
            right_tapers = []
            left_tapers = []
            for name in layout:
                if name in ("taper3", "taper4"):
                    right_tapers.append(layout[name])
                    continue
                if name in ("taper1", "taper2"):
                    left_tapers.append(layout[name])
                    continue
                if name == "sbend1":
                    sbend1_port = layout[name].out_ports[0].position
                elif name == "sbend2":
                    sbend2_port = layout[name].out_ports[0].position
                elif name == "sbend3":
                    sbend3_port = layout[name].in_ports[0].position
                elif name == "sbend4":
                    sbend4_port = layout[name].in_ports[0].position
                elems += layout[name].flat_copy()

            r_angle_pts = list(i3.get_acute_angle_points(right_tapers).values())
            l_angle_pts = list(i3.get_acute_angle_points(left_tapers).values())
            # Separate results per side: the right search can no longer leak
            # its point into the left side (the old code shared one variable).
            r_extend, right_pt = self._find_extension_point(
                r_angle_pts, sbn_length + mmi_length + tap_length, mmi_center_y, rightwards=True
            )
            l_extend, left_pt = self._find_extension_point(
                l_angle_pts, sbn_length + tap_length, mmi_center_y, rightwards=False
            )

            fill_dx = tap_length / 10.0
            # append, not +=: += on a list would extend() it with the element's
            # iteration rather than adding the element itself.
            if r_extend:
                fill_x = right_pt[0] + fill_dx
                right_tapers.append(
                    i3.Boundary(
                        layer=i3.TECH.PPLAYER.WG,
                        shape=i3.Shape(
                            points=[
                                right_pt,
                                (fill_x, sbend3_port[1]),
                                (fill_x, sbend4_port[1]),
                            ]
                        ),
                    )
                )
            if l_extend:
                fill_x = left_pt[0] - fill_dx
                left_tapers.append(
                    i3.Boundary(
                        layer=i3.TECH.PPLAYER.WG,
                        shape=i3.Shape(
                            points=[
                                left_pt,
                                (fill_x, sbend1_port[1]),
                                (fill_x, sbend2_port[1]),
                            ]
                        ),
                    )
                )

            final_elems = i3.merge_elements(right_tapers, i3.TECH.PPLAYER.WG)
            final_elems += i3.merge_elements(left_tapers, i3.TECH.PPLAYER.WG)
            final_elems += elems
            for elem in final_elems:
                elem.shape = elem.shape.remove_straight_angles()
            return final_elems

        def _generate_ports(self, ports):
            """Add inputs at x = 0 and outputs at the far right, at y = 0 and y = inp_separation."""
            tt = WG_TMPL(name=self.name + "_tt").Layout(core_width=self.inp_width).cell
            out_x = self.mmi_length + 2 * self.tap_length + 2 * self.sbn_length
            ports += i3.OpticalPort(
                name="in0",
                position=(0.0, 0.0),
                angle=180.0,
                trace_template=tt,
            )
            ports += i3.OpticalPort(
                name="in1",
                position=(0.0, self.inp_separation),
                angle=180.0,
                trace_template=tt,
            )
            ports += i3.OpticalPort(
                name="out0",
                position=(out_x, 0),
                angle=0.0,
                trace_template=tt,
            )
            ports += i3.OpticalPort(
                name="out1",
                position=(out_x, self.inp_separation),
                angle=0.0,
                trace_template=tt,
            )
            return ports

    class Netlist(i3.NetlistFromLayout):
        pass


class MMI_nxmPcell(i3.PCell):
    """
    NxM MMI.

    A rectangular multimode section (mmi_length x mmi_width) spans x = 0 to
    x = mmi_length and is centred on y = 0. ``n_inputs`` linear tapers sit on
    its left face and ``n_outputs`` on its right face.
    """

    _doc_properties = []
    _name_prefix = "MMI_NxM"

    trace_template = i3.TraceTemplateProperty(doc="Template of the waveguide", locked=True)

    def _default_trace_template(self):
        return WG_TMPL(name=self.name + "_tt")

    class Layout(i3.LayoutView):

        _doc_properties = [
            "inp_width",
            "tap_length",
            "tap_width_out",
            "mmi_width",
            "mmi_length",
            "waveguide_offset_input",
            "waveguide_spacing_input",
            "waveguide_spacing_output",
            "n_inputs",
            "n_outputs",
        ]

        inp_width = i3.PositiveNumberProperty(doc="Input waveguide width of tapers", default=0.6)
        tap_length = i3.PositiveNumberProperty(doc="Length of tapers", default=6)
        tap_width_out = i3.PositiveNumberProperty(doc="Output width of tapers", default=3)
        mmi_width = i3.PositiveNumberProperty(default=10.0, doc="Width of the MMI section.")
        mmi_length = i3.PositiveNumberProperty(default=20.0, doc="Length of the MMI section.")
        waveguide_offset_input = i3.NonNegativeNumberProperty(default=2.0, doc="Offset input with respect to center.")
        waveguide_spacing_input = i3.NonNegativeNumberProperty(doc="Spacing between the input waveguides.")
        waveguide_spacing_output = i3.PositiveNumberProperty(doc="Spacing between the output waveguides.")
        n_inputs = i3.PositiveIntProperty(default=1, doc="The number of inputs for the MMI")
        n_outputs = i3.PositiveIntProperty(default=2, doc="The number of outputs for the MMI")

        def _default_trace_template(self):
            lv = self.cell.trace_template.get_default_view(self)
            lv.set(core_width=self.inp_width)
            return lv

        def _default_waveguide_spacing_input(self):
            return self.tap_width_out + 1.0

        def _default_waveguide_spacing_output(self):
            return self.tap_width_out + 1.0

        def _input_ys(self):
            """Return the y coordinates of the input waveguides, bottom to top.

            The group is centred on ``waveguide_offset_input`` with pitch
            ``waveguide_spacing_input``.
            """
            spacing = self.waveguide_spacing_input
            start = self.waveguide_offset_input - (self.n_inputs - 1) / 2.0 * spacing
            return [start + k * spacing for k in range(self.n_inputs)]

        def _output_ys(self):
            """Return the y coordinates of the output waveguides, bottom to top.

            The group is centred on y = 0 with pitch ``waveguide_spacing_output``.
            The original code centred it using the *input* spacing, which
            misplaced the outputs whenever the two spacings differed.
            """
            spacing = self.waveguide_spacing_output
            start = -(self.n_outputs - 1) / 2.0 * spacing
            return [start + k * spacing for k in range(self.n_outputs)]

        def _generate_elements(self, elems):
            """Draw the MMI rectangle plus one linear taper per input and per output.

            All shapes use the trace template's core layer. The narrow taper
            ends use the trace template's core width.
            """
            core_layer = self.trace_template.core_layer
            inp_width = self.trace_template.core_width
            mmi_length = self.mmi_length
            tap_length = self.tap_length
            tap_width_out = self.tap_width_out

            elems += i3.Rectangle(
                layer=core_layer,
                center=(0.5 * mmi_length, 0.0),
                box_size=(mmi_length, self.mmi_width),
            )
            # Input tapers: from x = -tap_length (narrow) to the MMI face at x = 0 (wide).
            for y in self._input_ys():
                elems += i3.Wedge(
                    layer=core_layer,
                    begin_coord=(-tap_length, y),
                    end_coord=(0.0, y),
                    begin_width=inp_width,
                    end_width=tap_width_out,
                )
            # Output tapers: from the MMI face at x = mmi_length (wide) outward (narrow).
            for y in self._output_ys():
                elems += i3.Wedge(
                    layer=core_layer,
                    begin_coord=(mmi_length, y),
                    end_coord=(mmi_length + tap_length, y),
                    begin_width=tap_width_out,
                    end_width=inp_width,
                )
            return elems

        def _generate_ports(self, ports):
            """Add ``in{k}`` ports at x = -tap_length and ``out{k}`` ports at x = mmi_length + tap_length."""
            tap_length = self.tap_length
            trace_template = self.trace_template
            # NOTE(review): this mutates the view's trace template so the port
            # templates carry inp_width. It is a no-op for the default template
            # (built with core_width=inp_width). Kept for backward compatibility.
            trace_template.set(core_width=self.inp_width)

            for k, y in enumerate(self._input_ys()):
                ports += i3.OpticalPort(
                    name="in{}".format(k),
                    position=(-tap_length, y),
                    angle=180.0,
                    trace_template=trace_template,
                )

            out_x = self.mmi_length + tap_length
            for k, y in enumerate(self._output_ys()):
                ports += i3.OpticalPort(
                    name="out{}".format(k),
                    position=(out_x, y),
                    angle=0.0,
                    trace_template=trace_template,
                )
            return ports

    class Netlist(i3.NetlistFromLayout):
        pass
