from ..waveguides.cells import StripWaveguide, SBend
import ipkiss3.all as i3


class DirectionalCouplerPcell(i3.PCell):
    """
    Directional coupler built from two waveguide arms.

    Each arm is an S-bend, a straight coupling section and a second S-bend.
    The two straight sections are laid out side by side, one above the other.

    Child cells:
        straight -- waveguide cell used for both coupling sections.
        sbend    -- S-bend cell used for all four input/output bends.

    Ports: in0/out0 belong to the lower arm, in1/out1 to the upper arm.
    """

    _doc_properties = []
    _name_prefix = "DirectionalCoupler"
    straight = i3.ChildCellProperty(locked=True)
    sbend = i3.ChildCellProperty(locked=True)

    def _default_sbend(self):
        # Child cell name is this cell's name with an "sbend" suffix.
        child_name = self.name + "sbend"
        return SBend(name=child_name)

    def _default_straight(self):
        # Child cell name is this cell's name with a "straight" suffix.
        child_name = self.name + "straight"
        return StripWaveguide(name=child_name)

    class Layout(i3.LayoutView):
        """Layout of the coupler: geometry of the bends and coupling section."""

        _doc_properties = ["length", "separation", "inp_length", "inp_separation", "width"]

        length = i3.PositiveNumberProperty(default=30.0, doc="Length of the Coupling section [um]")
        separation = i3.PositiveNumberProperty(default=0.5, doc="Gap of the Coupler [um]")
        inp_length = i3.PositiveNumberProperty(default=20.0, doc="Length of the S-bends [um]")
        inp_separation = i3.PositiveNumberProperty(default=10.0, doc="Separation of the S-bends [um]")
        width = i3.PositiveNumberProperty(default=0.6, doc="Width of the waveguides")

        def _default_sbend(self):
            """Configure the shared S-bend view from the coupler parameters."""
            # Half of the difference between the input separation and the
            # coupling gap; both arms use the same bend, one of them mirrored.
            # NOTE(review): this is <= 0 whenever inp_separation <= separation;
            # confirm that SBend accepts (or rejects) a non-positive height.
            bend_height = 0.5 * (self.inp_separation - self.separation)
            view = self.cell.sbend.get_default_view(self)
            view.set(width=self.width, length=self.inp_length, height=bend_height)
            return view

        def _default_straight(self):
            """Configure the shared coupling-section view: a horizontal trace of `length`."""
            view = self.cell.straight.get_default_view(self)
            trace = [(0.0, 0.0), (self.length, 0.0)]
            view.set(width=self.width, shape=trace)
            return view

        def _generate_instances(self, insts):
            """Place both arms with i3.place_and_route and add them to `insts`."""
            wg = self.straight
            bend = self.sbend

            # Same child views are reused: two straights, four bends.
            children = {"straight1": wg, "straight2": wg}
            for bend_name in ("sbend1", "sbend2", "sbend3", "sbend4"):
                children[bend_name] = bend

            # Vertical centre-to-centre offset of the coupling sections:
            # one waveguide width plus the gap between them.
            pitch = self.width + self.separation

            # Mirror the two bends that have to curve the opposite way.
            mirror_specs = [
                i3.FlipV("sbend2"),
                i3.FlipV("sbend4"),
            ]
            # Lower arm, anchored at the origin: sbend1 -> straight1 -> sbend4.
            lower_arm = [
                i3.Place("sbend1", (0, 0)),
                i3.Join("sbend1:out0", "straight1:in0"),
                i3.Join("straight1:out0", "sbend4:in0"),
            ]
            # Upper arm: straight2 sits `pitch` above straight1, with sbend2
            # feeding into it and sbend3 leaving from it.
            upper_arm = [
                i3.PlaceRelative("straight2:in0", "straight1:in0", (0, pitch)),
                i3.Join("straight2:in0", "sbend2:out0"),
                i3.Join("straight2:out0", "sbend3:in0"),
            ]

            insts += i3.place_and_route(
                insts=children,
                specs=mirror_specs + lower_arm + upper_arm,
            )
            return insts

        def _generate_ports(self, ports):
            """Expose the four outer S-bend ports as the coupler's ports."""
            # (instance port, exposed name): index 0 = lower arm, 1 = upper arm.
            port_pairs = [
                ("sbend1:in0", "in0"),
                ("sbend2:in0", "in1"),
                ("sbend3:out0", "out1"),
                ("sbend4:out0", "out0"),
            ]
            ports += i3.expose_ports(self.instances, dict(port_pairs))
            return ports

    class Netlist(i3.NetlistFromLayout):
        # Relies entirely on i3.NetlistFromLayout; no netlist-specific overrides.
        pass
