from .waveguide_traces import WG_TMPL
import ipkiss3.all as i3
from ipkiss.technology import get_technology
import numpy as np


# Active IPKISS technology object. StripWaveguide reads TECH.WG.BEND_RADIUS from it
# as its default bend radius.
TECH = get_technology()


class BaseWaveguide(i3.RoundedWaveguide):
    """Shared base for the rounded waveguide cells defined in this module.

    Subclasses override the trace template, the bend radius and the default
    shape. On its own it produces a short straight waveguide.
    """

    _doc_properties = ["name"]

    class Layout(i3.RoundedWaveguide.Layout):
        """Layout view whose default shape is a straight segment on the x axis."""

        _doc_properties = ["shape", "bend_radius"]

        def _default_shape(self):
            # Straight 20 um run starting at the origin.
            origin = (0.0, 0.0)
            end_point = (20.0, 0.0)
            return [origin, end_point]


class StripWaveguide(BaseWaveguide):
    """Strip waveguide with a configurable core width.

    The cell owns its own ``WG_TMPL`` trace template. The layout view pushes
    ``width`` into that template's ``core_width``. The external ports and terms
    carry a ``0`` suffix (``in0`` / ``out0``).
    """

    _name_prefix = "STRIP"
    _doc_properties = ["name", "trace_template"]

    trace_template = i3.TraceTemplateProperty(doc="Template of the waveguide")

    def _default_trace_template(self):
        # One template cell per waveguide, named after the waveguide.
        template_name = self.name + "_tt"
        return WG_TMPL(name=template_name)

    class Layout(BaseWaveguide.Layout):
        """Layout view: a 100 um straight by default, bend radius from ``TECH.WG``."""

        _doc_properties = ["width", "bend_radius", "shape"]

        width = i3.PositiveNumberProperty(default=0.6, doc="Width of the core [um]")

        def _default_trace_template(self):
            # Take the template's default layout view and apply this view's width to it.
            template_view = self.cell.trace_template.get_default_view(i3.LayoutView)
            template_view.set(core_width=self.width)
            return template_view

        def _default_bend_radius(self):
            return TECH.WG.BEND_RADIUS

        def _default_shape(self):
            return [(0.0, 0.0), (100.0, 0.0)]

        def _generate_ports(self, ports):
            # Re-export every port of the "contents" instance, appending "0" to
            # its name, then sort by name so the port order is stable.
            instances = self.instances
            name_map = {}
            for port in instances["contents"].ports:
                name_map["contents:" + port.name] = port.name + "0"
            exposed = i3.expose_ports(instances=instances, port_name_map=name_map)
            exposed.sort(key=lambda port: port.name)
            return exposed

    class Netlist(BaseWaveguide.Netlist):
        """Netlist view that wires the outer in0/out0 terms to the inner waveguide."""

        layout_view = i3.LayoutViewProperty(locked=True, doc="Layout view on which this netlist is based")

        def _default_layout_view(self):
            return self.cell.get_default_view(i3.LayoutView)

        def _generate_terms(self, terms):
            for term_name in ("in0", "out0"):
                terms += i3.OpticalTerm(name=term_name)
            return terms

        def _generate_nets(self, nets):
            # Each outer term links to the same-named (unsuffixed) term of "contents".
            for outer_name, inner_name in (("in0", "in"), ("out0", "out")):
                nets += i3.OpticalLink(self.terms[outer_name], self.instances["contents"].terms[inner_name])
            return nets


class StripWaveguide_369(StripWaveguide):
    """Strip waveguide variant for a 369 nm operating wavelength.

    The trace template and the core width (0.6 um) are both locked, so only
    the name, shape and bend radius remain user-facing.
    """

    _doc_properties = ["name"]

    # Locked: the inherited default trace template is always used.
    trace_template = i3.LockedProperty()

    class Layout(StripWaveguide.Layout):
        _doc_properties = ["shape", "bend_radius"]

        # Core width fixed for this wavelength.
        width = i3.PositiveNumberProperty(locked=True, default=0.6, doc="Width of the core [um]")


class StripWaveguide_405(StripWaveguide):
    """Strip waveguide variant for a 405 nm operating wavelength.

    The trace template and the core width (0.7 um) are both locked, so only
    the name, shape and bend radius remain user-facing.
    """

    _doc_properties = ["name"]

    # Locked: the inherited default trace template is always used.
    trace_template = i3.LockedProperty()

    class Layout(StripWaveguide.Layout):
        _doc_properties = ["shape", "bend_radius"]

        # Core width fixed for this wavelength.
        width = i3.PositiveNumberProperty(locked=True, default=0.7, doc="Width of the core [um]")


class _LinearTaper(i3.LinearWindowWaveguideTransition):
    """Linear width transition between two ``WG_TMPL`` cross sections.

    The start and end templates are separate per-cell ``WG_TMPL`` instances
    whose core widths come from ``initial_width`` and ``final_width``. Port
    and term names carry a ``0`` suffix.
    """

    def _default_start_trace_template(self):
        start_name = self.name + "_start_tt"
        return WG_TMPL(name=start_name)

    def _default_end_trace_template(self):
        end_name = self.name + "_end_tt"
        return WG_TMPL(name=end_name)

    class Layout(i3.LinearWindowWaveguideTransition.Layout):

        initial_width = i3.PositiveNumberProperty(default=0.5, doc="Width of the taper at the start.")
        final_width = i3.PositiveNumberProperty(default=1.0, doc="Width of the taper at the end.")
        length = i3.NonNegativeNumberProperty(doc="Length of the taper.")

        def _default_length(self):
            return 10.0

        def _default_start_position(self):
            return 0.0, 0.0

        def _default_end_position(self):
            # The taper runs along +x, ending `length` away from the origin.
            return self.length, 0.0

        def _default_start_trace_template(self):
            start_view = self.cell.start_trace_template.get_default_view(i3.LayoutView)
            start_view.set(core_width=self.initial_width)
            return start_view

        def _default_end_trace_template(self):
            end_view = self.cell.end_trace_template.get_default_view(i3.LayoutView)
            end_view.set(core_width=self.final_width)
            return end_view

        def _generate_ports(self, ports):
            # Keep the parent's ports but append "0" to every port name.
            renamed = []
            for port in super(_LinearTaper.Layout, self)._generate_ports(ports):
                renamed.append(port.modified_copy(name=port.name + "0"))
            return renamed

    class Netlist(i3.LinearWindowWaveguideTransition.Netlist):
        def _generate_terms(self, terms):
            for term_name in ("in0", "out0"):
                terms += i3.OpticalTerm(name=term_name)
            return terms


class TaperPCell(_LinearTaper):
    """Public linear taper cell.

    Geometry and ports come from ``_LinearTaper``. This class only sets the
    name prefix and which layout properties are documented.
    """

    _name_prefix = "Taper"
    _doc_properties = []

    class Layout(_LinearTaper.Layout):
        _doc_properties = ["initial_width", "final_width", "length"]


class SBend(i3.Waveguide):
    """Raised-cosine S-bend waveguide.

    The centre line is y(x) = h * (x/L - sin(2*pi*x/L) / (2*pi)) for
    0 <= x <= L. It has zero slope at both ends, so both faces point along +x.
    """

    _name_prefix = "SBend"
    _doc_properties = []

    trace_template = i3.TraceTemplateProperty(doc="Template of the waveguide", locked=True)

    def _default_trace_template(self):
        template_name = self.name + "_tt"
        return WG_TMPL(name=template_name)

    class Layout(i3.Waveguide.Layout):
        _doc_properties = ["width", "length", "height"]

        width = i3.PositiveNumberProperty(default=0.6, doc="Width of the sbend [um]")
        length = i3.PositiveNumberProperty(default=10, doc="Length of the sbend [um]")
        height = i3.NumberProperty(default=5, doc="Height of the sbend [um]")

        def _default_trace_template(self):
            template_view = self.cell.trace_template.get_default_view(i3.LayoutView)
            template_view.set(core_width=self.width)
            return template_view

        def _default_shape(self):
            span_x = self.length
            span_y = self.height
            # Number of sample points equals the (rounded) grids-per-unit value.
            num_points = round(i3.get_grids_per_unit())
            xs = np.linspace(0.0, span_x, num_points)
            ys = span_y * (xs / span_x - 1 / (2 * np.pi) * np.sin(2 * np.pi * xs / span_x))
            curve = i3.Shape(list(zip(xs, ys)))
            # Both ends are horizontal.
            curve.end_face_angle = 0.0
            curve.start_face_angle = 0.0
            return curve

        def _generate_ports(self, ports):
            # Re-export the "contents" ports with a "0" suffix, sorted by name.
            instances = self.instances
            name_map = {}
            for port in instances["contents"].ports:
                name_map["contents:" + port.name] = port.name + "0"
            exposed = i3.expose_ports(instances=instances, port_name_map=name_map)
            exposed.sort(key=lambda port: port.name)
            return exposed

    class Netlist(i3.Waveguide.Netlist):
        layout_view = i3.LayoutViewProperty(locked=True, doc="Layout view on which this netlist is based")

        def _default_layout_view(self):
            return self.cell.get_default_view(i3.LayoutView)

        def _generate_terms(self, terms):
            for term_name in ("in0", "out0"):
                terms += i3.OpticalTerm(name=term_name)
            return terms

        def _generate_nets(self, nets):
            for outer_name, inner_name in (("in0", "in"), ("out0", "out")):
                nets += i3.OpticalLink(self.terms[outer_name], self.instances["contents"].terms[inner_name])
            return nets


class EulerBendPcell(i3.Waveguide):
    """
    Euler Bend Waveguide.

    Turns by ``arc_angle`` degrees (positive = counter-clockwise, negative =
    clockwise), starting at the origin with a 0-degree start face. Corners of
    a control polyline are rounded with ``i3.EulerRoundingAlgorithm``.
    """

    _name_prefix = "EulerBend"
    _doc_properties = []

    trace_template = i3.TraceTemplateProperty(doc="Template of the waveguide", locked=True)

    def _default_trace_template(self):
        return WG_TMPL(name=self.name + "_tt")

    class Layout(i3.Waveguide.Layout):
        _doc_properties = ["min_radius", "p_factor", "arc_angle", "width"]

        width = i3.PositiveNumberProperty(default=0.6, doc="Width of the Euler bend [um]")
        arc_angle = i3.NumberProperty(default=90.0, doc="Angle of the Euler bend [degree]")
        p_factor = i3.PositiveNumberProperty(
            default=0.2,
            doc="Fraction of the bend having linearly increasing curvature",
        )
        min_radius = i3.PositiveNumberProperty(default=i3.TECH.TRACE.BEND_RADIUS, doc="Radius of the bend [um]")

        def _default_trace_template(self):
            lv = self.cell.trace_template.get_default_view(i3.LayoutView)
            lv.set(core_width=self.width)
            return lv

        def _default_shape(self):
            """Build a cornered control polyline for the turn and Euler-round it.

            Returns:
                The rounded shape, with start face at 0 degrees and end face at
                the signed turn angle.
            """
            bend_radius = self.min_radius
            p = self.p_factor
            # Turn direction: +1 for counter-clockwise (positive angle), -1 for clockwise.
            sgn = (self.arc_angle / abs(self.arc_angle)) if self.arc_angle != 0 else 1
            # Unsigned magnitude of the turn, folded into [0, 360).
            angle = (sgn * self.arc_angle) % 360.0
            shape = i3.Shape([i3.Coord2(0.0, 0.0)])

            # Two control points per full 90-degree quadrant of the turn.
            for n in range(int(angle / 90.0)):
                shape.append(shape[-1].move_polar_copy(bend_radius, sgn * n * 90.0))
                shape.append(shape[-1].move_polar_copy(bend_radius, sgn * (n + 1) * 90.0))

            # Remaining (< 90 degree) part of the turn: two legs of length
            # R * tan(a / 2) meeting at one corner. For multiples of 90 degrees
            # angle_left is 0, so these are zero-length moves.
            angle_left = angle % 90.0
            leg = bend_radius * np.tan(i3.DEG2RAD * angle_left / 2.0)
            shape.append(shape[-1].move_polar_copy(leg, sgn * (angle - angle_left)))
            shape.append(shape[-1].move_polar_copy(leg, sgn * angle))

            ra = i3.EulerRoundingAlgorithm(p=p, use_effective_radius=True)
            rounded_shape = ra(original_shape=shape, radius=bend_radius)
            rounded_shape.start_face_angle = 0.0
            # The polyline turns by sgn * angle, so the end face must carry the
            # sign as well. Using the unsigned magnitude pointed the output face
            # the wrong way for clockwise (negative) arc angles.
            rounded_shape.end_face_angle = sgn * angle
            return rounded_shape

        def _generate_ports(self, ports):
            # Re-export the "contents" ports with a "0" suffix, sorted by name.
            insts = self.instances
            contents_ports = insts["contents"].ports
            ports = i3.expose_ports(
                instances=self.instances,
                port_name_map={"contents:" + p.name: p.name + "0" for p in contents_ports},
            )
            ports.sort(key=lambda x: x.name)
            return ports

    class Netlist(i3.Waveguide.Netlist):
        layout_view = i3.LayoutViewProperty(locked=True, doc="Layout view on which this netlist is based")

        def _default_layout_view(self):
            return self.cell.get_default_view(i3.LayoutView)

        def _generate_terms(self, terms):
            terms += i3.OpticalTerm(name="in0")
            terms += i3.OpticalTerm(name="out0")
            return terms

        def _generate_nets(self, nets):
            # Link the outer in0/out0 terms to the in/out terms of "contents".
            nets += i3.OpticalLink(self.terms["in0"], self.instances["contents"].terms["in"])
            nets += i3.OpticalLink(self.terms["out0"], self.instances["contents"].terms["out"])
            return nets


class ArcPathPcell(BaseWaveguide):
    """Circular-arc waveguide.

    The arc starts at the origin with a 0-degree start face and turns by
    ``arc_angle`` degrees with radius ``radius``. Positive angles go up and
    negative angles go down.
    """

    _name_prefix = "ArcPath"

    _doc_properties = []

    trace_template = i3.TraceTemplateProperty(doc="Template of the waveguide", locked=True)

    def _default_trace_template(self):
        template_name = self.name + "_tt"
        return WG_TMPL(name=template_name)

    class Layout(BaseWaveguide.Layout):
        _doc_properties = ["arc_angle", "radius", "width"]
        width = i3.PositiveNumberProperty(default=0.6, doc="Width of the arc path [um]")
        radius = i3.PositiveNumberProperty(default=i3.TECH.TRACE.BEND_RADIUS, doc="Radius of the waveguide [um]")
        arc_angle = i3.AngleProperty(
            default=90.0,
            doc="Angle of the arc in the interval, negative angle to go down and positive to go up [degree]",
        )

        def _default_trace_template(self):
            template_view = self.cell.trace_template.get_default_view(i3.LayoutView)
            template_view.set(core_width=self.width)
            return template_view

        def _default_bend_radius(self):
            # NOTE(review): the arc below is already a discretized bend, and the
            # RoundedWaveguide base rounds its vertices again with this same
            # radius. Confirm this double rounding is intended.
            return self.radius

        def _default_shape(self):
            arc = i3.ShapeBendRelative(start_point=(0.0, 0.0), angle_amount=self.arc_angle, radius=self.radius)
            arc.start_face_angle = 0.0
            arc.end_face_angle = self.arc_angle
            return arc

        def _generate_ports(self, ports):
            # Re-export the "contents" ports with a "0" suffix, sorted by name.
            instances = self.instances
            name_map = {}
            for port in instances["contents"].ports:
                name_map["contents:" + port.name] = port.name + "0"
            exposed = i3.expose_ports(instances=instances, port_name_map=name_map)
            exposed.sort(key=lambda port: port.name)
            return exposed

    class Netlist(BaseWaveguide.Netlist):
        layout_view = i3.LayoutViewProperty(locked=True, doc="Layout view on which this netlist is based")

        def _default_layout_view(self):
            return self.cell.get_default_view(i3.LayoutView)

        def _generate_terms(self, terms):
            for term_name in ("in0", "out0"):
                terms += i3.OpticalTerm(name=term_name)
            return terms

        def _generate_nets(self, nets):
            for outer_name, inner_name in (("in0", "in"), ("out0", "out")):
                nets += i3.OpticalLink(self.terms[outer_name], self.instances["contents"].terms[inner_name])
            return nets
