from ..waveguides.cells import SBend, TaperPCell
from ..waveguides.waveguide_traces import WG_TMPL
import ipkiss3.all as i3


class YBranchPcell(i3.PCell):
    """Y-branch splitter.

    A taper placed at the origin feeds two instances of one S-bend cell that
    start at the taper's end; the second instance is flipped vertically so
    the two arms diverge towards the ``out0`` (lower) and ``out1`` (upper)
    ports.
    """

    # No cell-level properties are listed for generated documentation.
    _doc_properties = []

    _name_prefix = "YBranch"
    # Both children are declared locked; their defaults are derived below.
    sbend = i3.ChildCellProperty(locked=True)
    taper = i3.ChildCellProperty(locked=True)

    def _default_sbend(self):
        # Child cell name = this cell's name + "sbend" (no separator).
        child_name = self.name + "sbend"
        return SBend(name=child_name)

    def _default_taper(self):
        # Child cell name = this cell's name + "taper" (no separator).
        child_name = self.name + "taper"
        return TaperPCell(name=child_name)

    class Layout(i3.LayoutView):
        _doc_properties = ["tap_length", "use_taper", "sbn_separation", "sbn_length", "width"]

        tap_length = i3.PositiveNumberProperty(default=10.0, doc="Length of the taper [um]")
        sbn_separation = i3.PositiveNumberProperty(default=10.0, doc="Separation of the S-bends [um]")
        sbn_length = i3.PositiveNumberProperty(default=8.0, doc="Length of the S-bends [um]")
        width = i3.PositiveNumberProperty(default=0.6, doc="Width of the waveguides [um]")
        use_taper = i3.BoolProperty(default=True, doc="Use Taper")

        def _default_sbend(self):
            """Configure the S-bend child view from this layout's parameters.

            The bend height is ``(sbn_separation + use_taper * width) / 2``.
            Each bend starts ``use_taper * width / 2`` off-axis in
            ``_generate_elements``, so (assuming ``height`` is the lateral
            rise of the SBend -- NOTE(review): confirm against SBend) the
            bend ends land on the output port positions.
            """
            view = self.cell.sbend.get_default_view(self)
            bend_height = (self.sbn_separation + self.use_taper * self.width) / 2.0
            view.set(
                width=self.width,
                length=self.sbn_length,
                height=bend_height,
            )
            return view

        def _default_taper(self):
            """Configure the taper child view.

            The taper goes from ``width`` to ``width + use_taper * width``,
            i.e. it doubles the width when ``use_taper`` is set and keeps it
            constant otherwise. Its length is ``tap_length``.
            """
            view = self.cell.taper.get_default_view(self)
            view.set(
                initial_width=self.width,
                final_width=self.width + self.use_taper * self.width,
                length=self.tap_length,
            )
            return view

        def _generate_elements(self, elems):
            """Place the taper and both S-bends, then merge them on the WG layer.

            The taper instance is flattened into ``elems``. The two S-bend
            instances are passed to ``i3.merge_elements`` together with
            ``elems``. When ``i3.get_acute_angle_points`` reports any points
            for the S-bends, a rectangular patch is added first. The patch
            spans from the upper to the lower S-bend input and extends
            ``sbn_length / 10`` in +x, presumably to close the acute gap at
            the split. The merged list (not ``elems``) is returned, after
            ``remove_straight_angles()`` has been applied to each shape.
            """
            taper_view = self.taper
            bend_view = self.sbend
            # Each S-bend start sits half of this value above/below the axis
            # (zero when use_taper is False, so both bends start on the axis).
            spread = self.use_taper * self.width
            placed = i3.place_and_route(
                insts={"taper": taper_view, "sbend1": bend_view, "sbend2": bend_view},
                specs=[
                    i3.FlipV("sbend2"),
                    i3.Place("taper", (0, 0)),
                    i3.Place("sbend1", (self.tap_length, spread / 2.0)),
                    i3.Place("sbend2", (self.tap_length, -spread / 2.0)),
                ],
            )

            bend_instances = []
            for inst_name in placed:
                if inst_name in ("sbend1", "sbend2"):
                    # S-bends stay as instances; they are merged below.
                    bend_instances.append(placed[inst_name])
                else:
                    # Anything else (the taper) is flattened into the output directly.
                    elems += placed[inst_name].flat_copy()

            acute_points = list(i3.get_acute_angle_points(bend_instances).values())
            if acute_points:
                top = placed["sbend1"].ports["in0"].position
                bottom = placed["sbend2"].ports["in0"].position
                stub = self.sbn_length / 10.0
                patch = i3.Shape(
                    points=[
                        top,
                        bottom,
                        (bottom[0] + stub, bottom[1]),
                        (top[0] + stub, top[1]),
                    ]
                )
                elems += i3.Boundary(layer=i3.TECH.PPLAYER.WG, shape=patch)

            merged = i3.merge_elements(bend_instances + elems, i3.TECH.PPLAYER.WG)
            for polygon in merged:
                polygon.shape = polygon.shape.remove_straight_angles()
            return merged

        def _generate_ports(self, ports):
            """Add the input port at the origin and the two output ports.

            ``in0`` faces 180 degrees at (0, 0). ``out0``/``out1`` face
            0 degrees at x = ``tap_length + sbn_length`` and
            y = -/+ ``(sbn_separation + 2 * use_taper * width) / 2``. All three
            ports share one trace template built with ``core_width=width``.
            """
            trace_tmpl = WG_TMPL(name=self.name + "_tt").Layout(core_width=self.width).cell
            x_out = self.tap_length + self.sbn_length
            y_out = (self.sbn_separation + 2 * self.use_taper * self.width) / 2.0
            port_table = (
                ("in0", (0.0, 0.0), 180.0),
                ("out0", (x_out, -y_out), 0.0),
                ("out1", (x_out, y_out), 0.0),
            )
            for port_name, port_position, port_angle in port_table:
                ports += i3.OpticalPort(
                    name=port_name,
                    position=port_position,
                    angle=port_angle,
                    trace_template=trace_tmpl,
                )
            return ports

    class Netlist(i3.NetlistFromLayout):
        pass
