Skip to content

Commit

Permalink
Fixed issue with generate_prompt_context
Browse files Browse the repository at this point in the history
  • Loading branch information
coordt committed Oct 7, 2023
1 parent d833c0a commit 4756660
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 43 deletions.
24 changes: 17 additions & 7 deletions cookie_composer/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@
import tempfile
from collections import OrderedDict
from enum import Enum
from functools import reduce
from pathlib import Path
from typing import Any, Dict, List, MutableMapping, Optional

import click
from cookiecutter.config import get_user_config
from cookiecutter.generate import generate_files
from cookiecutter.generate import apply_overwrites_to_context, generate_files
from cookiecutter.main import _patch_import_path_for_repo
from pydantic import BaseModel, DirectoryPath, Field, model_validator

Expand All @@ -23,6 +22,7 @@
from cookie_composer.merge_files import MERGE_FUNCTIONS

from .templates.types import Template
from .utils import echo

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -87,15 +87,15 @@ def layer_name(self) -> str:
"""The name of the template layer."""
return self.template.name

def generate_prompt_context(
def generate_context(
self,
default_context: MutableMapping[str, Any],
) -> OrderedDict:
"""
Get the context for prompting the user for values.
The order of precedence is:
1.` initial_context` from the composition or command-line
1. `initial_context` from the composition or command-line
2. `default_context` from the user_config
3. `raw context` from the template
Expand All @@ -122,7 +122,15 @@ def generate_prompt_context(

# This pulls in the template context and overrides the values with the user config defaults
# and the defaults specified in the layer.
return OrderedDict(reduce(comprehensive_merge, [raw_context, user_context, layer_initial_context], {}))
if default_context:
try:
apply_overwrites_to_context(raw_context, default_context)
except ValueError as error:
echo(f"Invalid user default received: {error}")
if layer_initial_context:
apply_overwrites_to_context(raw_context, layer_initial_context)

return OrderedDict(raw_context)


class RenderedLayer(BaseModel):
Expand Down Expand Up @@ -239,10 +247,12 @@ def render_layer(
full_context = full_context or Context()
user_config = get_user_config(config_file=None, default_config=False)
repo_dir = layer_config.template.cached_path

default_context = user_config.get("default_context", {})
context_for_prompting = layer_config.generate_prompt_context(
context = layer_config.generate_context(
default_context=default_context,
)
context_for_prompting = {k: v for k, v in context.items() if k not in full_context}
layer_context = get_layer_context(
template_repo_dir=repo_dir,
context_for_prompting=context_for_prompting,
Expand Down Expand Up @@ -315,7 +325,7 @@ def get_layer_context(
with import_patch:
prompted_context = prompt_for_config(context_for_prompting, full_context, initial_context, no_input)
context_for_prompting.update(prompted_context)
return dict(context_for_prompting)
return context_for_prompting


def render_layers(
Expand Down
49 changes: 13 additions & 36 deletions tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,15 @@ def test_layer_config_generate_prompt_context(
4. raw context from the template
"""
layer_conf = LayerConfig(template=template_two, initial_context=initial_context, no_input=True)
context = layer_conf.generate_prompt_context(default_context=default_context)
context = layer_conf.generate_context(default_context=default_context)
assert context == expected


def test_get_layer_context(fixtures_path: Path, template_one: Template, tmp_path: Path):
layer_conf = LayerConfig(template=template_one, no_input=True)
user_config = get_user_config(config_file=None, default_config=False)

prompt_context = layer_conf.generate_prompt_context(user_config)
prompt_context = layer_conf.generate_context(user_config)
context = layers.get_layer_context(
template_one.repo.cached_source,
prompt_context,
Expand All @@ -317,15 +317,7 @@ def test_get_layer_context(fixtures_path: Path, template_one: Template, tmp_path
)
assert context == {
"_requirements": {"bar": ">=5.0.0", "foo": ""},
"abbreviations": {
"bb": "https://bitbucket.org/{0}",
"gh": "https://github.com/{0}.git",
"gl": "https://gitlab.com/{0}.git",
},
"cookiecutters_dir": str(tmp_path.joinpath("home/.cookiecutters")),
"default_context": {},
"project_name": "Fake Project Template",
"replay_dir": str(tmp_path.joinpath("home/.cookiecutter_replay")),
"repo_name": "fake-project-template",
"repo_slug": "fake-project-template",
"service_name": "foo",
Expand All @@ -348,39 +340,24 @@ def test_get_layer_context_with_extra(fixtures_path: Path, template_two: Templat
}
)
)
prompt_context = layer_conf.generate_prompt_context(user_config)
prompt_context = layer_conf.generate_context(user_config)
layer_context = layers.get_layer_context(
template_two.repo.cached_source,
prompt_context,
layer_conf.initial_context or {},
full_context,
no_input=layer_conf.no_input,
)
assert (
layer_context
== Context(
OrderedDict(
{
"project_name": "Fake Project Template2",
"repo_name": "fake-project-template2",
"project_slug": "fake-project-template-two",
"_requirements": OrderedDict([("bar", ">=5.0.0"), ("baz", "")]),
"lower_project_name": "fake project template2",
"repo_slug": "fake-project-template-two",
"service_name": "foo",
}
),
OrderedDict(
{
"project_name": "Fake Project Template2",
"repo_name": "fake-project-template2",
"repo_slug": "fake-project-template-two",
"service_name": "foo",
"_requirements": {"foo": "", "bar": ">=5.0.0"},
}
),
).flatten()
)
expected = {
"project_name": "Fake Project Template2",
"repo_name": "fake-project-template2",
"project_slug": "fake-project-template-two",
"_requirements": OrderedDict([("bar", ">=5.0.0"), ("baz", "")]),
"lower_project_name": "fake project template2",
"repo_slug": "fake-project-template-two",
"service_name": "foo",
}
assert layer_context == expected


@pytest.mark.parametrize(
Expand Down

0 comments on commit 4756660

Please sign in to comment.