from PIL import Image
import matplotlib.pyplot as plt
from skimage import measure
import numpy as np
import os
from xml.dom.minidom import getDOMImplementation


def create_svg_path(contour):
    """
    Create an SVG path string from a contour.

    Each point of ``contour`` is read as a ``(y, x)`` pair and emitted as
    ``x y``, which is the coordinate order SVG path commands expect. The
    path starts with a move-to (``M``) on the first point, continues with a
    line-to (``L``) for every following point, and is closed with ``Z``.

    Args:
        contour: A sized, indexable sequence of ``(y, x)`` points, e.g. a
            list of tuples or an ``(N, 2)`` array.

    Returns:
        The value for the ``d`` attribute of an SVG ``<path>`` element, or
        an empty string when ``contour`` has no points.
    """
    # len() rather than truthiness: the truth value of a multi-element
    # NumPy array is ambiguous and raises ValueError.
    if len(contour) == 0:
        return ""

    start = contour[0]
    commands = ["M {} {}".format(start[1], start[0])]
    commands.extend("L {} {}".format(x, y) for (y, x) in contour[1:])
    commands.append("Z")

    # Join once instead of growing a string with += for every point.
    return " ".join(commands)


def contours_to_svg(contours, width, height):
    """
    Build an SVG document outlining every contour and return it as XML text.

    The document consists of a root ``<svg>`` element sized ``width`` by
    ``height`` that holds a single ``<g>`` group (no fill, black stroke of
    width 1). The group receives one ``<path>`` child per contour, whose
    ``d`` attribute is produced by ``create_svg_path``.

    Args:
        contours: Iterable of contours, each accepted by ``create_svg_path``.
        width: Value for the root ``width`` attribute (stringified).
        height: Value for the root ``height`` attribute (stringified).

    Returns:
        The serialized document as returned by ``Document.toxml()``.
    """
    document = getDOMImplementation().createDocument(None, "svg", None)
    root = document.documentElement

    # Attributes are applied in a fixed order so the serialized output
    # keeps the same attribute sequence.
    root_attributes = (
        ("width", str(width)),
        ("height", str(height)),
        ("xmlns", "http://www.w3.org/2000/svg"),
    )
    for attr_name, attr_value in root_attributes:
        root.setAttribute(attr_name, attr_value)

    group = document.createElement("g")
    group_attributes = (
        ("fill", "none"),
        ("stroke", "black"),
        ("stroke-width", "1"),
    )
    for attr_name, attr_value in group_attributes:
        group.setAttribute(attr_name, attr_value)
    root.appendChild(group)

    for outline in contours:
        path_node = document.createElement("path")
        path_node.setAttribute("d", create_svg_path(outline))
        group.appendChild(path_node)

    return document.toxml()


def convert_to_vector(image_path, vector_output_path, num, threshold_value=128):
    """
    Trace the dark regions of an image and save their outlines as an SVG.

    The image is converted to 8-bit grayscale, every pixel darker than
    ``threshold_value`` is marked True in a boolean mask, and the mask's
    contours (found at level 0.7) are written as SVG paths to
    ``<vector_output_path>/<num>.svg``.

    Args:
        image_path: Path of the raster image to read.
        vector_output_path: Directory the SVG file is written into. It is
            not created here, so it must already exist.
        num: Value used as the output file's base name.
        threshold_value: Grayscale cutoff; pixels strictly below it count
            as foreground. Defaults to 128.
    """
    # Open inside a context manager so the file handle is released
    # immediately; otherwise batch conversion leaks one handle per image.
    # convert() loads the pixel data and returns a separate image, so the
    # grayscale copy stays usable after the source file is closed.
    with Image.open(image_path) as image:
        gray_image = image.convert('L')

    image_np = np.array(gray_image)
    binary_image = image_np < threshold_value

    contours = measure.find_contours(binary_image, 0.7)

    # NumPy shape is (rows, columns), i.e. (height, width).
    image_height, image_width = binary_image.shape

    svg_data = contours_to_svg(contours, image_width, image_height)

    # os.path.join avoids a hard-coded separator and does not turn an
    # empty output directory into a path under the filesystem root.
    svg_file_path = os.path.join(vector_output_path, f'{num}.svg')
    with open(svg_file_path, 'w', encoding='utf-8') as svg_file:
        svg_file.write(svg_data)


# input_path = '/Users/dakotagoldberg/cba/raspi_drawing_machine/code/stroke'
# output_path = '/Users/dakotagoldberg/cba/raspi_drawing_machine/code/vec'


# num = 0
# for filename in os.listdir(input_path):
#     num += 1
#     filepath = os.path.join(
#         input_path, filename)
#     # Check if it is a file
#     if os.path.isfile(filepath):
#         print(filename)

#         # name = "polite_cat.png"
#         convert_to_vector(f'{input_path}/{filename}',
#                           f'{output_path}', num)