#!/usr/bin/python3
import json
import os
import re
import shutil
import subprocess
import sys
import tempfile

from textual.app import App
from textual.containers import Container, Horizontal, VerticalScroll
from textual.widgets import Tree, Footer, Static
from textual.widgets import Markdown

def format_urls_as_markdown(text):
    """Convert plain URLs in text to markdown links, leaving existing markdown links alone.

    Existing ``[label](target)`` links are matched first and copied through
    unchanged, so only URLs outside such links are rewritten. This makes the
    function idempotent: running it again on its own output changes nothing.

    Args:
        text: Text that may contain ``http://`` or ``https://`` URLs.

    Returns:
        The text with every plain URL replaced by ``[url](url)``.
    """
    # An already formatted link; non-empty label and target on a single line.
    markdown_link = r'\[[^\]\n]+\]\([^)\n]+\)'
    # A plain URL: scheme, host (optionally percent-encoded), optional path that
    # stops at whitespace or ")" so a URL wrapped in parentheses is not overrun.
    url_pattern = r'https?://(?:[-\w.]|(?:%[\da-fA-F]{2}))+(?:/[^)\s]*)?'
    # The link alternative comes first so that a URL inside an existing link is
    # consumed as part of that link and never matched on its own.
    combined = re.compile(f'(?P<link>{markdown_link})|(?P<url>{url_pattern})')

    def _linkify(found):
        existing = found.group('link')
        if existing is not None:
            return existing
        url = found.group('url')
        return f'[{url}]({url})'

    return combined.sub(_linkify, text)

def is_grype_installed():
    """Check if an executable named ``grype`` can be found on the system's PATH.

    Uses ``shutil.which`` for the lookup rather than scanning
    ``os.environ["PATH"]`` by hand, so an unset PATH no longer raises KeyError.

    Returns:
        True if a ``grype`` executable was found, False otherwise.
    """
    return shutil.which("grype") is not None

def prompt_install_grype():
    """Ask on the terminal whether grype should be installed now.

    Returns:
        True only when the answer, stripped of surrounding whitespace and
        lower-cased, is exactly "y"; any other answer (including an empty
        one) returns False.
    """
    question = (
        "The grype binary is not located in $PATH. Would you like to install it now? [y/N]: "
    )
    answer = input(question)
    normalized = answer.strip().lower()
    return normalized == "y"

def install_grype():
    """Download Grype's install script and run it, installing into /usr/local/bin.

    The download and the installer run as two separate, individually checked
    processes instead of a ``curl ... | sh`` shell pipeline. In a plain shell
    pipeline only the exit status of ``sh`` is checked, so a failed download
    would hand ``sh`` an empty script and be reported as a successful install.

    Exits the process with status 1 if the download or the installer fails,
    or if ``curl`` / ``sh`` cannot be started.
    """
    install_script_url = "https://raw.githubusercontent.com/anchore/grype/main/install.sh"
    try:
        # Only stdout is captured; stderr stays attached to the terminal so
        # curl's own error message (-S) remains visible to the user.
        download = subprocess.run(
            ["curl", "-sSfL", install_script_url],
            stdout=subprocess.PIPE,
            check=True,
        )
        # Feed the downloaded script to sh on stdin, as the pipeline did.
        subprocess.run(
            ["sh", "-s", "--", "-b", "/usr/local/bin"],
            input=download.stdout,
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing curl or sh binary.
        print(f"Failed to install grype: {e}")
        sys.exit(1)
    print("Grype successfully installed.")

class Grummage(App):
    """Textual application for browsing the vulnerabilities Grype reports for an SBOM.

    The SBOM is scanned with ``grype -o json`` when the app mounts. The
    matches are shown in a tree on the left, grouped by package name, package
    type, vulnerability ID, or severity, and the details of the selected
    match are rendered as Markdown on the right.
    """

    # (key, action, description) bindings. This class defines no ``action_*``
    # methods for the view-switching and explain entries; those keys are
    # handled in ``on_key`` below.
    BINDINGS = [
        ("v", "load_tree_by_vulnerability", "by Vuln"),
        ("n", "load_tree_by_package_name", "by Pkg Name"),
        ("s", "load_tree_by_severity", "by Severity"),
        ("t", "load_tree_by_package_type", "by Type"),
        ("e", "explain_vulnerability", "Explain Vuln"),
        ("q", "quit", "Quit"),
        ("j", "simulate_key('down')", "Move down"),
        ("k", "simulate_key('up')", "Move up"),
        ("h", "simulate_key('left')", "Move left"),
        ("l", "simulate_key('right')", "Move right"),
    ]

    def __init__(self, sbom_file=None):
        """Create the app.

        Args:
            sbom_file: Path of the SBOM JSON file to scan. When falsy,
                ``load_sbom`` reads the SBOM from stdin instead.
        """
        super().__init__()
        self.sbom_file = sbom_file
        self.vulnerability_report = None
        # Raw JSON text of the last successful Grype scan; reused by
        # explain_vulnerability so the SBOM is not scanned again.
        self._grype_report_json = None
        # Selection state is defined up front so that pressing "e" before any
        # node has been selected does not raise AttributeError in on_key.
        self.selected_vuln_id = None
        self.selected_package_name = None
        self.selected_package_version = None
        self.detailed_text = None
        self.debug_log_file = open("debug_log.txt", "w", encoding="utf-8")

    def quit(self):
        """Exit the application."""
        self.exit()

    def debug_log(self, message):
        """Append one line to the debug log file and flush it immediately."""
        self.debug_log_file.write(message + "\n")
        self.debug_log_file.flush()  # Ensure immediate write

    async def on_mount(self):
        """Build the layout (tree, details pane, status bar, footer) and load the SBOM."""
        self.debug_log("on_mount: Starting application setup")

        # Widgets for the tree view, details display, and status bar
        self.tree_view = Tree("Vulnerabilities")
        self.details_display = Markdown("Select a node for more details.")
        self.status_bar = Static("Status: Initializing...")

        # Tree on the left, scrollable details on the right
        tree_container = Container(self.tree_view)
        details_container = VerticalScroll(self.details_display)
        # 35/65 side-by-side split
        tree_container.styles.width = "35%"
        details_container.styles.width = "65%"
        tree_container.styles.height = "98%"
        details_container.styles.height = "98%"

        main_layout = Horizontal(tree_container, details_container)
        main_layout.styles.height = "98%"

        # Mount the main layout, then the status bar and footer below it
        await self.mount(main_layout)
        await self.mount(self.status_bar)
        await self.mount(Footer())
        self.debug_log("on_mount: Layout mounted")

        # Load the SBOM from file or stdin
        await self.load_sbom()

    async def load_sbom(self):
        """Load the SBOM from ``self.sbom_file`` (or stdin), scan it, and fill the tree.

        On any failure the status bar is updated and the tree is left empty.
        """
        if self.sbom_file:
            self.debug_log(f"Loading SBOM from file: {self.sbom_file}")
            sbom_json = self.load_json(self.sbom_file)
            if sbom_json is None:
                # load_json already logged the reason; do not hand "null" to Grype.
                self.status_bar.update("Status: Failed to read SBOM from file.")
                return
        else:
            self.debug_log("Loading SBOM from stdin")
            try:
                sbom_json = json.load(sys.stdin)
            except json.JSONDecodeError as e:
                self.debug_log(f"Error reading SBOM from stdin: {e}")
                self.status_bar.update("Status: Failed to read SBOM from stdin.")
                return

        # Run Grype analysis on the loaded SBOM JSON
        self.vulnerability_report = self.call_grype(sbom_json)
        if self.vulnerability_report and "matches" in self.vulnerability_report:
            self.load_tree_by_package_name()
            self.status_bar.update("Status: Vulnerability data loaded. Press N, T, V, S to change views, or E to explain.")
            self.debug_log("Vulnerability data loaded into tree")
        else:
            self.status_bar.update("Status: No vulnerabilities found or unable to load data.")
            self.debug_log("No vulnerability data found")

    def load_json(self, file_path):
        """Load SBOM JSON from a file.

        Returns:
            The parsed JSON, or None if the file cannot be read or parsed
            (the error is written to the debug log).
        """
        try:
            with open(file_path, "r") as file:
                return json.load(file)
        except Exception as e:
            self.debug_log(f"Error loading SBOM JSON: {e}")
            return None

    def call_grype(self, sbom_json):
        """Run Grype on the SBOM and return its parsed JSON report.

        The SBOM is written to a temporary ``.json`` file that is passed to
        ``grype <file> -o json`` and removed again afterwards. On success the
        raw report text is also kept for ``explain_vulnerability``.

        Args:
            sbom_json: The SBOM as a JSON-serialisable object.

        Returns:
            The parsed report, or None if Grype fails or its output cannot be
            parsed. Errors go to the debug log rather than to stdout, because
            this runs while the TUI owns the terminal.
        """
        temp_file_path = None
        try:
            # delete=False because Grype opens the file by path after we close it.
            with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.json') as temp_file:
                temp_file_path = temp_file.name
                json.dump(sbom_json, temp_file)

            result = subprocess.run(
                ["grype", temp_file_path, "-o", "json"],
                capture_output=True,
                text=True
            )

            if result.returncode != 0:
                self.debug_log(f"Grype encountered an error: {result.stderr}")
                return None

            report = json.loads(result.stdout)
            self._grype_report_json = result.stdout
            return report

        except Exception as e:
            # Deliberate best-effort boundary: a failed scan must not crash the UI.
            self.debug_log(f"Error running Grype: {e}")
            return None
        finally:
            if temp_file_path is not None:
                try:
                    os.unlink(temp_file_path)
                except OSError as e:
                    self.debug_log(f"Could not remove temporary SBOM file: {e}")

    def _grype_matches(self):
        """Return the report's match list, or an empty list when no report is loaded."""
        report = self.vulnerability_report or {}
        return report.get("matches") or []

    @staticmethod
    def _build_node_data(match, severity=None):
        """Build the details dict stored on a tree leaf for one Grype match.

        Args:
            match: One entry of the report's "matches" list.
            severity: Severity to store; defaults to the match's own severity,
                or "Unknown" when the match has none.
        """
        vulnerability = match["vulnerability"]
        artifact = match["artifact"]
        if severity is None:
            severity = vulnerability.get("severity", "Unknown")
        # "or" also covers a present-but-empty versions list, which would
        # otherwise render as a blank "Fix Version" line.
        fix_versions = (vulnerability.get("fix") or {}).get("versions") or ["None"]
        return {
            "id": vulnerability["id"],
            "package_name": artifact["name"],
            "package_version": artifact["version"],
            "severity": severity,
            "fix_version": fix_versions,
            "related": match.get("relatedVulnerabilities") or [],
        }

    def load_tree_by_package_name(self):
        """Display vulnerabilities organized by package name."""
        self.tree_view.clear()
        matches_by_package = {}
        for match in self._grype_matches():
            matches_by_package.setdefault(match["artifact"]["name"], []).append(match)

        for package_name, matches in matches_by_package.items():
            package_node = self.tree_view.root.add(package_name)
            for match in matches:
                leaf = package_node.add_leaf(f"{match['vulnerability']['id']}")
                leaf.data = self._build_node_data(match)

    def load_tree_by_type(self):
        """Display vulnerabilities organized by package type, with package names under each type."""
        self.tree_view.clear()
        # package type -> package name -> matches
        matches_by_type = {}
        for match in self._grype_matches():
            artifact = match["artifact"]
            by_name = matches_by_type.setdefault(artifact["type"], {})
            by_name.setdefault(artifact["name"], []).append(match)

        for pkg_type, packages in matches_by_type.items():
            type_node = self.tree_view.root.add(pkg_type)
            for pkg_name, matches in packages.items():
                package_node = type_node.add(pkg_name)
                for match in matches:
                    leaf = package_node.add_leaf(f"{match['vulnerability']['id']}")
                    leaf.data = self._build_node_data(match)

    def load_tree_by_vulnerability(self):
        """Display vulnerabilities organized by vulnerability ID, with affected packages as leaves."""
        self.tree_view.clear()
        matches_by_vuln = {}
        for match in self._grype_matches():
            matches_by_vuln.setdefault(match["vulnerability"]["id"], []).append(match)

        for vuln_id, matches in matches_by_vuln.items():
            vuln_node = self.tree_view.root.add(vuln_id)
            for match in matches:
                leaf = vuln_node.add_leaf(f"{match['artifact']['name']}")
                leaf.data = self._build_node_data(match)

    def load_tree_by_severity(self):
        """Display vulnerabilities organized by severity, in fixed order."""
        self.tree_view.clear()
        severity_order = ["Critical", "High", "Medium", "Low", "Negligible", "Unknown"]

        matches_by_severity = {severity: [] for severity in severity_order}
        for match in self._grype_matches():
            severity = match["vulnerability"].get("severity", "Unknown")
            if severity not in matches_by_severity:
                # Anything outside the predefined categories is shown as Unknown.
                severity = "Unknown"
            matches_by_severity[severity].append(match)

        for severity in severity_order:
            matches = matches_by_severity[severity]
            if not matches:
                continue  # Only show severities that actually occur
            severity_node = self.tree_view.root.add(severity)
            for match in matches:
                vuln_id = match["vulnerability"]["id"]
                package_name = match["artifact"]["name"]
                leaf = severity_node.add_leaf(f"{vuln_id} ({package_name})")
                # Store the normalised severity so the details match the group.
                leaf.data = self._build_node_data(match, severity=severity)

    async def on_key(self, event):
        """Handle key presses: n/t/v/s switch the tree view, e explains the selected match."""
        key = event.key.lower()
        views = {
            "n": (self.load_tree_by_package_name, "Status: Viewing by package name."),
            "t": (self.load_tree_by_type, "Status: Viewing by package type."),
            "v": (self.load_tree_by_vulnerability, "Status: Viewing by vulnerability ID."),
            "s": (self.load_tree_by_severity, "Status: Viewing by severity."),
        }
        if key in views:
            if not self.vulnerability_report or "matches" not in self.vulnerability_report:
                # Nothing to regroup; keep the UI alive instead of indexing None.
                self.status_bar.update("Status: No vulnerability data loaded.")
                return
            load_view, status = views[key]
            load_view()
            self.status_bar.update(status)
        elif key == "e" and self.selected_vuln_id and self.detailed_text:
            self.status_bar.update(f"Status: Explaining {self.selected_vuln_id} in {self.selected_package_name} ({self.selected_package_version})")
            await self.explain_vulnerability(self.selected_vuln_id)

    async def on_tree_node_selected(self, event):
        """Show detailed information for the selected vulnerability.

        Nodes without stored data (the grouping nodes) clear the current
        selection and show a placeholder message instead.
        """
        details = event.node.data
        if not details:
            self.details_display.update("No data found for selected node.")
            self.selected_vuln_id = None
            self.selected_package_name = None
            self.selected_package_version = None
            self.detailed_text = None
            self.debug_log("No data found for selected node.")
            return

        self.selected_vuln_id = details["id"]
        self.selected_package_name = details["package_name"]
        self.selected_package_version = details["package_version"]
        self.status_bar.update(f"Status: Selected {details['id']} in {details['package_name']} ({details['package_version']})")

        # .get keeps a related entry without an id/dataSource from raising
        # KeyError inside the event handler.
        related_lines = "".join(
            f"* {related.get('id', 'Unknown')} ({related.get('dataSource', 'Unknown')})\n"
            for related in details["related"]
        )
        detail_text = (
            f"# {details['id']}\n\n"
            f"**Package:** {details['package_name']} ({details['package_version']})\n\n"
            f"**Severity:** {details['severity']}\n\n"
            f"**Fix Version:** {', '.join(details['fix_version'])}\n\n"
            f"**Related Vulnerabilities:**\n\n"
            f"{related_lines}"
        )

        # Convert URLs to markdown links
        detail_text = format_urls_as_markdown(detail_text)

        self.debug_log(f"Displaying details for {details['id']}")
        self.details_display.update(detail_text)
        self.detailed_text = detail_text

    def on_unmount(self):
        """Close the log file when the application exits."""
        self.debug_log_file.close()

    async def explain_vulnerability(self, vuln_id):
        """Run ``grype explain --id <vuln_id>`` and append its output to the details pane.

        The report text kept by ``call_grype`` is piped to ``grype explain``.
        Only when no report text is available is the SBOM file scanned again,
        which requires ``self.sbom_file`` to be set.

        NOTE(review): the subprocess calls are synchronous, so the coroutine
        does not yield while Grype runs.

        Args:
            vuln_id: The vulnerability ID to explain.
        """
        try:
            report_json = self._grype_report_json
            if report_json is None:
                if not self.sbom_file:
                    self.details_display.update(f"# Error\n\nNo Grype report available to explain {vuln_id}.")
                    self.debug_log(f"No Grype report available to explain {vuln_id}")
                    return

                analyze_result = subprocess.run(
                    ["grype", self.sbom_file, "-o", "json"],
                    capture_output=True,
                    text=True
                )
                if analyze_result.returncode != 0:
                    self.details_display.update(f"Error analyzing SBOM: {analyze_result.stderr}")
                    self.debug_log(f"Error analyzing SBOM for explanation: {analyze_result.stderr}")
                    return
                report_json = analyze_result.stdout
                self._grype_report_json = report_json

            explain_result = subprocess.run(
                ["grype", "explain", "--id", vuln_id],
                input=report_json,  # Grype's JSON report is read from stdin
                capture_output=True,
                text=True
            )

            if explain_result.returncode == 0:
                explanation = explain_result.stdout
                combined_text = (
                    f"{self.detailed_text or ''}\n\n\n"  # Blank lines for spacing
                    f"## Explanation for {vuln_id}:\n\n"
                    f"{explanation}"
                )
                combined_text = format_urls_as_markdown(combined_text)
                self.details_display.update(combined_text)
                self.debug_log(f"Displaying explanation for {vuln_id}")
            else:
                self.details_display.update(f"# Error\n\nFailed to explain {vuln_id}.\n\nError: {explain_result.stderr}")
                self.debug_log(f"Error explaining {vuln_id}: {explain_result.stderr}")

        except Exception as e:
            # Best-effort boundary: show the failure in the pane instead of crashing the app.
            self.details_display.update(f"# Error\n\nError explaining {vuln_id}: {e}")
            self.debug_log(f"Exception in explain_vulnerability: {e}")


if __name__ == "__main__":
    # An SBOM path is required. Say so and exit non-zero, instead of exiting
    # silently with status 0 when the argument is missing.
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} <sbom-file>", file=sys.stderr)
        sys.exit(2)
    sbom_file = sys.argv[1]

    # Grype must be available before the TUI starts; offer to install it.
    if not is_grype_installed():
        if prompt_install_grype():
            install_grype()
        else:
            print(
                "The grype binary is not located in $PATH and the option to install was deferred."
            )
            sys.exit(0)
    Grummage(sbom_file).run()
