Skip to content

Commit 5c30443

Browse files
committed
dependency inject app configuration
1 parent 4473a88 commit 5c30443

25 files changed

+379
-245
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ dependencies = [
5757
"pytest-xdist~=3.6",
5858
"pytest-asyncio~=0.24",
5959
"pytest-httpx~=0.30",
60+
"tomli-w>=1.2.0",
6061
]
6162

6263
[[tool.hatch.envs.hatch-test.matrix]]

src/shelloracle/__main__.py

Lines changed: 2 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,68 +1,4 @@
1-
import argparse
2-
import logging
3-
import sys
4-
from importlib.metadata import version
5-
6-
from shelloracle import shelloracle
7-
from shelloracle.config import initialize_config
8-
from shelloracle.settings import Settings
9-
from shelloracle.tty_log_handler import TtyLogHandler
10-
11-
logger = logging.getLogger(__name__)
12-
13-
14-
def configure_logging():
15-
root_logger = logging.getLogger()
16-
root_logger.setLevel(logging.DEBUG)
17-
18-
file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")
19-
file_handler = logging.FileHandler(Settings.shelloracle_home / "shelloracle.log")
20-
file_handler.setLevel(logging.DEBUG)
21-
file_handler.setFormatter(file_formatter)
22-
23-
tty_formatter = logging.Formatter("%(message)s")
24-
tty_handler = TtyLogHandler()
25-
tty_handler.setLevel(logging.WARNING)
26-
tty_handler.setFormatter(tty_formatter)
27-
28-
root_logger.addHandler(file_handler)
29-
root_logger.addHandler(tty_handler)
30-
31-
32-
def configure():
33-
# nest this import in a function to avoid expensive module loads
34-
from shelloracle.bootstrap import bootstrap_shelloracle
35-
36-
bootstrap_shelloracle()
37-
38-
39-
def parse_args() -> argparse.Namespace:
40-
parser = argparse.ArgumentParser()
41-
parser.add_argument("--version", action="version", version=f"{__package__} {version(__package__)}")
42-
43-
subparsers = parser.add_subparsers()
44-
configure_subparser = subparsers.add_parser("configure", help=f"install {__package__} keybindings")
45-
configure_subparser.set_defaults(action=configure)
46-
47-
return parser.parse_args()
48-
49-
50-
def main() -> None:
51-
args = parse_args()
52-
configure_logging()
53-
54-
if action := getattr(args, "action", None):
55-
action()
56-
sys.exit(0)
57-
58-
try:
59-
initialize_config()
60-
except FileNotFoundError:
61-
logger.warning("ShellOracle configuration not found. Run `shor configure` to initialize.")
62-
sys.exit(1)
63-
64-
shelloracle.cli()
65-
66-
671
if __name__ == "__main__":
2+
from shelloracle.cli import main
3+
684
main()

src/shelloracle/bootstrap.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from prompt_toolkit.shortcuts import confirm
1313

1414
from shelloracle.providers import Provider, Setting, get_provider, list_providers
15-
from shelloracle.settings import Settings
1615

1716
if TYPE_CHECKING:
1817
from collections.abc import Iterator, Sequence
@@ -104,7 +103,7 @@ def correct_name_setting():
104103
yield from correct_name_setting()
105104

106105

107-
def write_shelloracle_config(provider: type[Provider], settings: dict[str, Any]) -> None:
106+
def write_shelloracle_config(provider: type[Provider], settings: dict[str, Any], config_path: Path) -> None:
108107
config = tomlkit.document()
109108

110109
shor_table = tomlkit.table()
@@ -119,8 +118,7 @@ def write_shelloracle_config(provider: type[Provider], settings: dict[str, Any])
119118
provider_configuration_table.add(setting, value)
120119
provider_table.add(provider.name, provider_configuration_table)
121120

122-
filepath = Settings.shelloracle_home / "config.toml"
123-
with filepath.open("w") as config_file:
121+
with config_path.open("w") as config_file:
124122
tomlkit.dump(config, config_file)
125123

126124

@@ -164,7 +162,7 @@ def user_select_provider() -> type[Provider]:
164162
return get_provider(provider_name)
165163

166164

167-
def bootstrap_shelloracle() -> None:
165+
def bootstrap_shelloracle(config_path: Path) -> None:
168166
try:
169167
provider = user_select_provider()
170168
settings = user_configure_settings(provider)
@@ -173,5 +171,5 @@ def bootstrap_shelloracle() -> None:
173171
return
174172
except KeyboardInterrupt:
175173
return
176-
write_shelloracle_config(provider, settings)
174+
write_shelloracle_config(provider, settings, config_path)
177175
install_keybindings()

src/shelloracle/cli/__init__.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,24 @@
11
import logging
22
import sys
3-
from importlib.metadata import version
3+
from pathlib import Path
44

55
import click
66

77
from shelloracle import shelloracle
8+
from shelloracle.cli.application import Application
89
from shelloracle.cli.config import config
9-
from shelloracle.config import initialize_config
10-
from shelloracle.settings import Settings
10+
from shelloracle.config import Configuration
1111
from shelloracle.tty_log_handler import TtyLogHandler
1212

1313
logger = logging.getLogger(__name__)
1414

1515

16-
def configure_logging():
16+
def configure_logging(log_path: Path):
1717
root_logger = logging.getLogger()
1818
root_logger.setLevel(logging.DEBUG)
1919

2020
file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")
21-
file_handler = logging.FileHandler(Settings.shelloracle_home / "shelloracle.log")
21+
file_handler = logging.FileHandler(log_path)
2222
file_handler.setLevel(logging.DEBUG)
2323
file_handler.setFormatter(file_formatter)
2424

@@ -32,21 +32,26 @@ def configure_logging():
3232

3333

3434
@click.group(invoke_without_command=True)
35-
@click.version_option(version=version("shelloracle"))
35+
@click.version_option()
3636
@click.pass_context
3737
def cli(ctx):
3838
"""ShellOracle command line interface."""
39-
configure_logging()
39+
app = Application()
40+
configure_logging(app.log_path)
4041

41-
# If no subcommand is invoked, run the main CLI
42-
if ctx.invoked_subcommand is None:
43-
try:
44-
initialize_config()
45-
except FileNotFoundError:
46-
logger.warning("ShellOracle configuration not found. Run `shor config init` to initialize.")
47-
sys.exit(1)
42+
try:
43+
app.configuration = Configuration.from_file(app.config_path)
44+
except FileNotFoundError:
45+
logger.warning("Configuration not found. Run `shor config init` to initialize.")
46+
sys.exit(1)
4847

49-
shelloracle.cli()
48+
ctx.obj = app
49+
50+
if ctx.invoked_subcommand is not None:
51+
# If no subcommand is invoked, run the main CLI
52+
return
53+
54+
shelloracle.cli(app)
5055

5156

5257
cli.add_command(config)

src/shelloracle/cli/application.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from pathlib import Path
2+
3+
from shelloracle.config import Configuration
4+
5+
shelloracle_home = Path.home() / ".shelloracle"
6+
shelloracle_home.mkdir(exist_ok=True)
7+
8+
9+
class Application:
10+
configuration: Configuration
11+
12+
def __init__(self):
13+
self.config_path = shelloracle_home / "config.toml"
14+
self.log_path = shelloracle_home / "shelloracle.log"

src/shelloracle/cli/config/init.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import click
22

3+
from shelloracle.cli import Application
4+
35

46
@click.command()
5-
def init():
7+
@click.pass_obj
8+
def init(app: Application):
69
"""Install shelloracle keybindings."""
710
# nest this import in a function to avoid expensive module loads
811
from shelloracle.bootstrap import bootstrap_shelloracle
912

10-
bootstrap_shelloracle()
13+
bootstrap_shelloracle(app.config_path)

src/shelloracle/config.py

Lines changed: 13 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77

88
from yaspin.spinners import SPINNERS_DATA
99

10-
from shelloracle.settings import Settings
11-
1210
if TYPE_CHECKING:
1311
from pathlib import Path
1412

13+
1514
if sys.version_info < (3, 11):
1615
import tomli as tomllib
1716
else:
@@ -21,15 +20,13 @@
2120

2221

2322
class Configuration(Mapping):
24-
def __init__(self, filepath: Path) -> None:
23+
def __init__(self, config: dict[str, Any]) -> None:
2524
"""ShellOracle application configuration
2625
27-
:param filepath: Path to the configuration file
26+
:param config: configuration dict
2827
:raises FileNotFoundError: if the configuration file does not exist
2928
"""
30-
self.filepath = filepath
31-
with filepath.open("rb") as config_file:
32-
self._config = tomllib.load(config_file)
29+
self._config = config
3330

3431
def __getitem__(self, key: str) -> Any:
3532
return self._config[key]
@@ -46,6 +43,10 @@ def __str__(self):
4643
def __repr__(self) -> str:
4744
return str(self)
4845

46+
@property
47+
def raw_config(self) -> dict[str, Any]:
48+
return self._config
49+
4950
@property
5051
def provider(self) -> str:
5152
return self["shelloracle"]["provider"]
@@ -60,31 +61,8 @@ def spinner_style(self) -> str | None:
6061
return None
6162
return style
6263

63-
64-
_config: Configuration | None = None
65-
66-
67-
def initialize_config() -> None:
68-
"""Initialize the configuration file
69-
70-
:raises RuntimeError: if the config is already initialized
71-
:raises FileNotFoundError: if the config file is not found
72-
"""
73-
global _config # noqa: PLW0603
74-
if _config:
75-
msg = "Configuration already initialized"
76-
raise RuntimeError(msg)
77-
filepath = Settings.shelloracle_home / "config.toml"
78-
_config = Configuration(filepath)
79-
80-
81-
def get_config() -> Configuration:
82-
"""Returns the global configuration object.
83-
84-
:return: the global configuration
85-
:raises RuntimeError: if the configuration is not initialized
86-
"""
87-
if _config is None:
88-
msg = "Configuration not initialized"
89-
raise RuntimeError(msg)
90-
return _config
64+
@classmethod
65+
def from_file(cls, filepath: Path):
66+
with filepath.open("rb") as config_file:
67+
config = tomllib.load(config_file)
68+
return cls(config)

src/shelloracle/providers/__init__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
from __future__ import annotations
22

33
from abc import abstractmethod
4-
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, runtime_checkable
5-
6-
from shelloracle.config import get_config
4+
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar
75

86
if TYPE_CHECKING:
97
from collections.abc import AsyncIterator
108

9+
from shelloracle.config import Configuration
10+
1111
system_prompt = (
1212
"Based on the following user description, generate a corresponding shell command. Focus solely "
1313
"on interpreting the requirements and translating them into a single, executable Bash command. "
@@ -22,7 +22,6 @@ class ProviderError(Exception):
2222
"""LLM providers raise this error to gracefully indicate something has gone wrong."""
2323

2424

25-
@runtime_checkable
2625
class Provider(Protocol):
2726
"""
2827
LLM Provider Protocol
@@ -31,6 +30,10 @@ class Provider(Protocol):
3130
"""
3231

3332
name: str
33+
config: Configuration
34+
35+
def __init__(self, config: Configuration) -> None:
36+
self.config = config
3437

3538
@abstractmethod
3639
def generate(self, prompt: str) -> AsyncIterator[str]:
@@ -64,9 +67,8 @@ def __get__(self, instance: Provider, owner: type[Provider]) -> T:
6467
# inspect.get_members from determining the object type
6568
msg = "Settings must be accessed through a provider instance."
6669
raise AttributeError(msg)
67-
config = get_config()
6870
try:
69-
return config["provider"][owner.name][self.name]
71+
return instance.config["provider"][owner.name][self.name]
7072
except KeyError:
7173
if self.default is None:
7274
raise

src/shelloracle/providers/deepseek.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ class Deepseek(Provider):
1111
api_key = Setting(default="")
1212
model = Setting(default="deepseek-chat")
1313

14-
def __init__(self):
14+
def __init__(self, *args, **kwargs):
15+
super().__init__(*args, **kwargs)
1516
if not self.api_key:
1617
msg = "No API key provided"
1718
raise ProviderError(msg)

src/shelloracle/providers/google.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ class Google(Provider):
1111
api_key = Setting(default="")
1212
model = Setting(default="gemini-2.0-flash") # Assuming a default model name
1313

14-
def __init__(self):
14+
def __init__(self, *args, **kwargs):
15+
super().__init__(*args, **kwargs)
1516
if not self.api_key:
1617
msg = "No API key provided"
1718
raise ProviderError(msg)

src/shelloracle/providers/localai.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ class LocalAI(Provider):
1616
def endpoint(self) -> str:
1717
return f"http://{self.host}:{self.port}"
1818

19-
def __init__(self):
19+
def __init__(self, *args, **kwargs):
20+
super().__init__(*args, **kwargs)
2021
# Use a placeholder API key so the client will work
2122
self.client = AsyncOpenAI(api_key="sk-xxx", base_url=self.endpoint)
2223

src/shelloracle/providers/openai.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ class OpenAI(Provider):
1111
api_key = Setting(default="")
1212
model = Setting(default="gpt-3.5-turbo")
1313

14-
def __init__(self):
14+
def __init__(self, *args, **kwargs):
15+
super().__init__(*args, **kwargs)
1516
if not self.api_key:
1617
msg = "No API key provided"
1718
raise ProviderError(msg)

src/shelloracle/providers/openai_compat.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ class OpenAICompat(Provider):
1212
api_key = Setting(default="")
1313
model = Setting(default="")
1414

15-
def __init__(self):
15+
def __init__(self, *args, **kwargs):
16+
super().__init__(*args, **kwargs)
1617
if not self.api_key:
1718
msg = "No API key provided. Use a dummy placeholder if no key is required"
1819
raise ProviderError(msg)

0 commit comments

Comments
 (0)