Skip to content

Commit

Permalink
feature/cve-data-filter-flag (#643)
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanpulver authored Dec 9, 2024
1 parent 7654596 commit 25abf95
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 5 deletions.
121 changes: 118 additions & 3 deletions safety/scan/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from pathlib import Path

import json
import sys
from typing import Any, Dict, List, Optional, Set, Tuple
from typing_extensions import Annotated
Expand Down Expand Up @@ -49,7 +50,9 @@ class ScannableEcosystems(Enum):

def process_report(
obj: Any, console: Console, report: ReportModel, output: str,
save_as: Optional[Tuple[str, Path]], **kwargs
save_as: Optional[Tuple[str, Path]], detailed_output: bool = False,
filter_keys: Optional[List[str]] = None,
**kwargs
) -> Optional[str]:
"""
Processes and outputs the report based on the given parameters.
Expand All @@ -60,6 +63,8 @@ def process_report(
report (ReportModel): The report model.
output (str): The output format.
save_as (Optional[Tuple[str, Path]]): The save-as format and path.
detailed_output (bool): Whether detailed output is enabled.
filter_keys (Optional[List[str]]): Keys to filter from the JSON output.
kwargs: Additional keyword arguments.
Returns:
Expand Down Expand Up @@ -162,6 +167,12 @@ def process_report(

if output is ScanOutput.JSON or ScanOutput.is_format(output, ScanOutput.SPDX):
if output is ScanOutput.JSON:
if detailed_output:
report_to_output = add_cve_details_to_report(report_to_output, obj.project.files)

if filter_keys:
report_to_output = filter_json_keys(report_to_output, filter_keys)

kwargs = {"json": report_to_output}
else:
kwargs = {"data": report_to_output}
Expand All @@ -175,6 +186,95 @@ def process_report(
return report_url


def filter_json_keys(json_string: str, keys: List[str]) -> str:
"""
Filters the given JSON string by the specified top-level keys.
Args:
json_string (str): The JSON string to filter.
keys (List[str]): List of top-level keys to include in the output.
Returns:
str: A JSON string containing only the specified keys.
"""
report_dict = json.loads(json_string)
filtered_data = {key: report_dict[key] for key in keys if key in report_dict}
return json.dumps(filtered_data, indent=4)


def filter_valid_cves(vulnerabilities: List[Any]) -> List[Dict[str, Any]]:
"""
Filters and returns valid CVE details from a list of vulnerabilities.
Args:
vulnerabilities (List[Any]): A list of vulnerabilities, which may include invalid data types.
Returns:
List[Dict[str, Any]]: A list of filtered CVE details that are either strings or dictionaries.
"""
return [
cve for cve in vulnerabilities if isinstance(cve, str) or isinstance(cve, dict)
] #type:ignore


def sort_cve_data(cve_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Sorts CVE details by severity in descending order.
Args:
cve_data (List[Dict[str, Any]]): A list of CVE details dictionaries, each containing a 'severity' key.
Returns:
List[Dict[str, Any]]: The sorted list of CVE details, prioritized by severity (e.g., CRITICAL > HIGH > MEDIUM).
"""
severity_order = {key.name: id for (id, key) in enumerate(VulnerabilitySeverityLabels)}
return sorted(cve_data, key=lambda x: severity_order.get(x["severity"].upper(), 0), reverse=True)


def generate_cve_details(files: List[FileModel]) -> List[Dict[str, Any]]:
"""
Generate CVE details from the scanned files.
Args:
files (List[FileModel]): List of scanned file models.
Returns:
List[Dict[str, Any]]: List of CVE details sorted by severity.
"""
cve_data = []
for file in files:
for spec in file.results.get_affected_specifications():
for vuln in spec.vulnerabilities:
if vuln.CVE:
cve_data.append({
"package": spec.name,
"affected_version": str(spec.specifier),
"safety_vulnerability_id": vuln.vulnerability_id,
"CVE": filter_valid_cves(vuln.CVE),
"more_info": vuln.more_info_url,
"advisory": vuln.advisory,
"severity": vuln.severity.cvssv3.get("base_severity", "Unknown") if vuln.severity and vuln.severity.cvssv3 else "Unknown",
})
return sort_cve_data(cve_data)


def add_cve_details_to_report(report_to_output: str, files: List[FileModel]) -> str:
"""
Add CVE details to the JSON report output.
Args:
report_to_output (str): The current JSON string of the report.
files (List[FileModel]): List of scanned files containing vulnerability data.
Returns:
str: The updated JSON string with CVE details added.
"""
cve_details = generate_cve_details(files)
report_dict = json.loads(report_to_output)
report_dict["cve_details"] = cve_details
return json.dumps(report_dict)


def generate_updates_arguments() -> List:
"""
Generates a list of file types and update limits for apply fixes.
Expand Down Expand Up @@ -250,7 +350,11 @@ def scan(ctx: typer.Context,
typer.Option("--apply-fixes",
help=SCAN_APPLY_FIXES,
show_default=False)
] = False
] = False,
filter_keys: Annotated[
Optional[List[str]],
typer.Option("--filter", help="Filter output by specific top-level JSON keys.")
] = None,
):
"""
Scans a project (defaulted to the current directory) for supply-chain security and configuration issues
Expand Down Expand Up @@ -465,7 +569,18 @@ def sort_vulns_by_score(vuln: Vulnerability) -> int:
ignored_vulns_data=ignored_vulns_data
)

report_url = process_report(ctx.obj, console, report, **{**ctx.params})
report_url = process_report(
obj=ctx.obj,
console=console,
report=report,
output=output,
save_as=save_as if save_as and all(save_as) else None,
detailed_output=detailed_output,
filter_keys=filter_keys,
**{k: v for k, v in ctx.params.items() if k not in {"detailed_output", "output", "save_as", "filter_keys"}}
)


project_url = f"{SAFETY_PLATFORM_URL}{ctx.obj.project.url_path}"

if apply_updates:
Expand Down
4 changes: 2 additions & 2 deletions tests/scan/test_command.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import unittest

from unittest.mock import patch, Mock
from click.testing import CliRunner
from safety.cli import cli
Expand All @@ -13,7 +14,7 @@ def setUp(self):
self.runner = CliRunner(mix_stderr=False)
self.dirname = os.path.dirname(__file__)

def test_scan(self):
def test_scan(self):
result = self.runner.invoke(cli, ["--stage", "cicd", "scan", "--target", self.dirname, "--output", "json"])
self.assertEqual(result.exit_code, 1)

Expand All @@ -22,4 +23,3 @@ def test_scan(self):

result = self.runner.invoke(cli, ["--stage", "cicd", "scan", "--target", self.dirname, "--output", "screen"])
self.assertEqual(result.exit_code, 1)

0 comments on commit 25abf95

Please sign in to comment.