diff --git a/src/elevenlabs/client.py b/src/elevenlabs/client.py index 6fd75a53..63c66ebe 100644 --- a/src/elevenlabs/client.py +++ b/src/elevenlabs/client.py @@ -19,6 +19,15 @@ def get_base_url_host(base_url: str) -> str: return httpx.URL(base_url).host +def _resolve_api_key(api_key: typing.Optional[str]) -> str: + resolved = api_key if api_key is not None else os.getenv("ELEVENLABS_API_KEY") + if not resolved: + raise ValueError( + "Please pass in your ElevenLabs API Key or export ELEVENLABS_API_KEY in your environment." + ) + return resolved + + class ElevenLabs(BaseElevenLabs): """ Use this class to access the different functions within the SDK. You can instantiate any number of clients with different configuration that will propogate to these functions. @@ -47,14 +56,15 @@ def __init__( *, base_url: typing.Optional[str] = None, environment: ElevenLabsEnvironment = ElevenLabsEnvironment.PRODUCTION, - api_key: typing.Optional[str] = os.getenv("ELEVENLABS_API_KEY"), + api_key: typing.Optional[str] = None, timeout: typing.Optional[float] = 240, httpx_client: typing.Optional[httpx.Client] = None ): + resolved_api_key = _resolve_api_key(api_key) super().__init__( base_url=base_url, environment=environment, - api_key=api_key, + api_key=resolved_api_key, timeout=timeout, httpx_client=httpx_client ) @@ -93,14 +103,15 @@ def __init__( *, base_url: typing.Optional[str] = None, environment: ElevenLabsEnvironment = ElevenLabsEnvironment.PRODUCTION, - api_key: typing.Optional[str] = os.getenv("ELEVENLABS_API_KEY"), + api_key: typing.Optional[str] = None, timeout: typing.Optional[float] = 240, httpx_client: typing.Optional[httpx.AsyncClient] = None ): + resolved_api_key = _resolve_api_key(api_key) super().__init__( base_url=base_url, environment=environment, - api_key=api_key, + api_key=resolved_api_key, timeout=timeout, httpx_client=httpx_client ) diff --git a/tests/test_client_init.py b/tests/test_client_init.py new file mode 100644 index 00000000..e2d27ca9 --- /dev/null +++ b/tests/test_client_init.py @@ -0,0 +1,70 @@ +import os +from unittest import mock + +import pytest + +from elevenlabs.client import ElevenLabs, AsyncElevenLabs, _resolve_api_key + + +class TestResolveApiKey: + def test_explicit_key(self): + assert _resolve_api_key("my-key") == "my-key" + + def test_env_var_fallback(self): + with mock.patch.dict(os.environ, {"ELEVENLABS_API_KEY": "env-key"}): + assert _resolve_api_key(None) == "env-key" + + def test_explicit_key_takes_precedence_over_env(self): + with mock.patch.dict(os.environ, {"ELEVENLABS_API_KEY": "env-key"}): + assert _resolve_api_key("explicit-key") == "explicit-key" + + def test_missing_key_raises(self): + with mock.patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError, match="API Key"): + _resolve_api_key(None) + + def test_empty_string_does_not_fall_through_to_env(self): + with mock.patch.dict(os.environ, {"ELEVENLABS_API_KEY": "env-key"}): + with pytest.raises(ValueError, match="API Key"): + _resolve_api_key("") + + def test_empty_env_var_raises(self): + with mock.patch.dict(os.environ, {"ELEVENLABS_API_KEY": ""}, clear=True): + with pytest.raises(ValueError, match="API Key"): + _resolve_api_key(None) + + +class TestElevenLabsInit: + def test_explicit_api_key(self): + client = ElevenLabs(api_key="test-key") + headers = client._client_wrapper.get_headers() + assert headers["xi-api-key"] == "test-key" + + def test_env_var_api_key(self): + with mock.patch.dict(os.environ, {"ELEVENLABS_API_KEY": "env-key"}): + client = ElevenLabs() + headers = client._client_wrapper.get_headers() + assert headers["xi-api-key"] == "env-key" + + def test_missing_api_key_raises(self): + with mock.patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError, match="API Key"): + ElevenLabs() + + +class TestAsyncElevenLabsInit: + def test_explicit_api_key(self): + client = AsyncElevenLabs(api_key="test-key") + headers = client._client_wrapper.get_headers() + assert headers["xi-api-key"] == "test-key" + + def test_env_var_api_key(self): + with mock.patch.dict(os.environ, {"ELEVENLABS_API_KEY": "env-key"}): + client = AsyncElevenLabs() + headers = client._client_wrapper.get_headers() + assert headers["xi-api-key"] == "env-key" + + def test_missing_api_key_raises(self): + with mock.patch.dict(os.environ, {}, clear=True): + with pytest.raises(ValueError, match="API Key"): + AsyncElevenLabs() diff --git a/tests/test_convai_imports.py b/tests/test_convai_imports.py index 7cd9f0b8..08beaaa3 100644 --- a/tests/test_convai_imports.py +++ b/tests/test_convai_imports.py @@ -2,5 +2,5 @@ def test_convai_imports(): - client = ElevenLabs(api_key="") + client = ElevenLabs(api_key="test") client.conversational_ai.agents