Skip to content

Commit 92b2829

Browse files
tomwojcikVukW
andauthored
Minor feature requests (#13)
* add test for headers in response when exception is raised in view * Check if context already filled (#12) * add tests for checking if context is available * extract base for uuid plugins, add force_new_uuid, add uuid validation Co-authored-by: Viacheslav Kukushkin <vy.kukushkin@gmail.com>
1 parent 79aa8ff commit 92b2829

File tree

10 files changed

+206
-46
lines changed

10 files changed

+206
-46
lines changed

starlette_context/ctx.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections import UserDict
22
from typing import Any
33

4+
from contextvars import copy_context
45
from starlette_context import _request_scope_context_storage
56

67

@@ -32,6 +33,9 @@ def data(self) -> dict:
3233
"outside of the request-response cycle."
3334
) from e
3435

36+
def exists(self) -> bool:
37+
return _request_scope_context_storage in copy_context()
38+
3539
def copy(self) -> dict:
3640
"""
3741
Read only context data.
Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,6 @@
1-
import uuid
2-
from typing import Optional
3-
4-
from starlette.requests import Request
5-
from starlette.responses import Response
6-
71
from starlette_context.header_keys import HeaderKeys
8-
from starlette_context.plugins.plugin import Plugin
2+
from starlette_context.plugins.plugin_uuid import PluginUUIDBase
93

104

11-
class CorrelationIdPlugin(Plugin):
5+
class CorrelationIdPlugin(PluginUUIDBase):
126
key = HeaderKeys.correlation_id
13-
14-
async def extract_value_from_header_by_key(
15-
self, request: Request
16-
) -> Optional[str]:
17-
await super(
18-
CorrelationIdPlugin, self
19-
).extract_value_from_header_by_key(request)
20-
if self.value is None:
21-
self.value = uuid.uuid4().hex
22-
return self.value
23-
24-
async def enrich_response(self, response: Response) -> None:
25-
await self._add_kv_to_response_headers(response)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import uuid
2+
from typing import Optional
3+
4+
from starlette.requests import Request
5+
from starlette.responses import Response
6+
7+
from starlette_context.plugins.plugin import Plugin
8+
9+
10+
class PluginUUIDBase(Plugin):
11+
uuid_functions_mapper = {4: uuid.uuid4}
12+
13+
def __init__(self, force_new_uuid: bool = False, version: int = 4):
14+
super().__init__()
15+
if version not in self.uuid_functions_mapper:
16+
raise TypeError(f"UUID version {version} is not supported.")
17+
self.force_new_uuid = force_new_uuid
18+
self.version = version
19+
20+
def validate_uuid(self, uuid_to_validate: str) -> None:
21+
try:
22+
uuid.UUID(uuid_to_validate, version=self.version)
23+
except Exception as e:
24+
raise ValueError("Wrong uuid") from e
25+
26+
def get_new_uuid(self) -> str:
27+
func = self.uuid_functions_mapper[self.version]
28+
return func().hex
29+
30+
async def extract_value_from_header_by_key(
31+
self, request: Request
32+
) -> Optional[str]:
33+
if not self.force_new_uuid:
34+
await super().extract_value_from_header_by_key(request)
35+
36+
if self.value is None:
37+
self.value = self.get_new_uuid()
38+
39+
self.validate_uuid(self.value)
40+
41+
return self.value
42+
43+
async def enrich_response(self, response: Response) -> None:
44+
await self._add_kv_to_response_headers(response)
Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,6 @@
1-
import uuid
2-
from typing import Optional
3-
4-
from starlette.requests import Request
5-
from starlette.responses import Response
6-
71
from starlette_context.header_keys import HeaderKeys
8-
from starlette_context.plugins.plugin import Plugin
2+
from starlette_context.plugins.plugin_uuid import PluginUUIDBase
93

104

11-
class RequestIdPlugin(Plugin):
5+
class RequestIdPlugin(PluginUUIDBase):
126
key = HeaderKeys.request_id
13-
14-
async def extract_value_from_header_by_key(
15-
self, request: Request
16-
) -> Optional[str]:
17-
await super(RequestIdPlugin, self).extract_value_from_header_by_key(
18-
request
19-
)
20-
if self.value is None:
21-
self.value = uuid.uuid4().hex
22-
return self.value
23-
24-
async def enrich_response(self, response: Response) -> None:
25-
await self._add_kv_to_response_headers(response)

tests/conftest.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@
77

88
from starlette_context import middleware
99
from starlette_context.header_keys import HeaderKeys
10+
import uuid
1011

11-
dummy_correlation_id = "dummy_correlation_id"
12-
dummy_request_id = "dummy_request_id"
12+
dummy_correlation_id = uuid.uuid4().hex
13+
dummy_request_id = uuid.uuid4().hex
1314
dummy_user_agent = "dummy_user_agent"
1415
dummy_date = "Wed, 01 Jan 2020 04:27:12 GMT"
1516
dummy_forwarded_for = "203.0.113.19"
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import json
2+
3+
from starlette.applications import Starlette
4+
5+
from starlette.requests import Request
6+
from starlette.responses import JSONResponse, Response
7+
from starlette.testclient import TestClient
8+
9+
from starlette_context import context, middleware
10+
11+
12+
app = Starlette()
13+
app.add_middleware(middleware.ContextMiddleware)
14+
15+
16+
@app.route("/")
17+
async def index(request: Request) -> Response:
18+
return JSONResponse({"exists": context.exists()})
19+
20+
21+
client = TestClient(app)
22+
23+
24+
def test_context_existence_in_request_response_cycle():
25+
resp = client.get("/")
26+
assert json.loads(resp.content) == {"exists": True}
27+
28+
29+
def test_context_outside_of_request_response_cycle():
30+
assert context.exists() is False
31+
resp = client.get("/")
32+
assert context.exists() is False

tests/test_integration/test_plugins_response_headers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from starlette_context import context, plugins
1111
from starlette_context.middleware import ContextMiddleware
1212
from tests.conftest import dummy_correlation_id, dummy_request_id
13+
from starlette_context.header_keys import HeaderKeys
1314

1415
middleware = [
1516
Middleware(
@@ -46,8 +47,8 @@ def dt_serializator(o):
4647
def test_response_headers(headers):
4748
response = client.get("/", headers=headers)
4849
assert 2 == len(response.headers)
49-
cid_header = response.headers["x-correlation-id"]
50-
rid_header = response.headers["x-request-id"]
50+
cid_header = response.headers[HeaderKeys.correlation_id.lower()]
51+
rid_header = response.headers[HeaderKeys.request_id.lower()]
5152
assert dummy_correlation_id == cid_header
5253
assert dummy_request_id == rid_header
5354

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import json
2+
import uuid
3+
from typing import NoReturn
4+
5+
from starlette.middleware import Middleware
6+
7+
from starlette_context import plugins
8+
from starlette.applications import Starlette
9+
10+
from starlette.requests import Request
11+
from starlette.responses import JSONResponse
12+
from starlette.testclient import TestClient
13+
14+
from starlette_context import middleware
15+
from starlette_context.header_keys import HeaderKeys
16+
from starlette_context import context
17+
18+
19+
class CustomException(Exception):
20+
pass
21+
22+
23+
async def custom_exception_handler(request: Request, exc: Exception):
24+
return JSONResponse({"exception": "handled"}, headers=context.data)
25+
26+
27+
async def general_exception_handler(request: Request, exc: Exception):
28+
return JSONResponse({"exception": "handled"}, headers=context.data)
29+
30+
31+
middleware = [
32+
Middleware(
33+
middleware.ContextMiddleware, plugins=(plugins.RequestIdPlugin(),)
34+
)
35+
]
36+
exception_handlers = {
37+
CustomException: custom_exception_handler,
38+
Exception: general_exception_handler,
39+
}
40+
41+
42+
app = Starlette(exception_handlers=exception_handlers, middleware=middleware,)
43+
44+
headers = {HeaderKeys.request_id: uuid.uuid4().hex}
45+
46+
47+
@app.route("/")
48+
async def index(_) -> NoReturn:
49+
raise RuntimeError
50+
51+
52+
@app.route("/custom-exc")
53+
async def index(_) -> NoReturn:
54+
raise CustomException
55+
56+
57+
client = TestClient(app)
58+
59+
60+
def test_exception_handling_that_is_not_resulting_in_500():
61+
resp = client.get("/custom-exc", headers=headers)
62+
assert json.loads(resp.content) == {"exception": "handled"}
63+
assert (
64+
resp.headers.get(HeaderKeys.request_id)
65+
== headers[HeaderKeys.request_id]
66+
)

tests/test_unit/test_plugins/test_correlation_id.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,15 @@ async def test_process_request_for_existing_header(
2020
assert dummy_correlation_id == plugin.value
2121

2222

23+
@pytest.mark.asyncio
24+
async def test_invalid_correlation_id_uuid(
25+
plugin: plugins.CorrelationIdPlugin, mocked_request: Request
26+
):
27+
mocked_request.headers[HeaderKeys.correlation_id] = "invalid_uuid"
28+
with pytest.raises(ValueError):
29+
await plugin.process_request(mocked_request)
30+
31+
2332
@pytest.mark.asyncio
2433
async def test_process_request_for_missing_header(
2534
plugin: plugins.CorrelationIdPlugin, mocked_request: Request
@@ -52,3 +61,35 @@ async def test_enrich_response_str(
5261
dummy_correlation_id
5362
== mocked_response.headers[HeaderKeys.correlation_id]
5463
)
64+
65+
66+
def test_version_cant_map_to_function():
67+
with pytest.raises(TypeError):
68+
plugins.CorrelationIdPlugin(version=123)
69+
70+
71+
@pytest.mark.asyncio
72+
async def test_force_new_uuid(
73+
plugin: plugins.CorrelationIdPlugin,
74+
mocked_request: Request,
75+
mocked_response: Response,
76+
):
77+
plugin.force_new_uuid = True
78+
await plugin.process_request(mocked_request)
79+
await plugin.enrich_response(mocked_response)
80+
81+
assert (
82+
dummy_correlation_id
83+
!= mocked_response.headers[HeaderKeys.correlation_id]
84+
)
85+
86+
87+
@pytest.mark.asyncio
88+
async def test_uuid_validation(
89+
plugin: plugins.CorrelationIdPlugin,
90+
mocked_request: Request,
91+
mocked_response: Response,
92+
):
93+
mocked_request.headers[HeaderKeys.correlation_id] = "invalid_uuid"
94+
with pytest.raises(ValueError):
95+
await plugin.process_request(mocked_request)

tests/test_unit/test_plugins/test_request_id.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,12 @@ async def test_enrich_response_str(
4949
await plugin.enrich_response(mocked_response)
5050

5151
assert dummy_request_id == mocked_response.headers[HeaderKeys.request_id]
52+
53+
54+
@pytest.mark.asyncio
55+
async def test_invalid_request_id_uuid(
56+
plugin: plugins.RequestIdPlugin, mocked_request: Request
57+
):
58+
mocked_request.headers[HeaderKeys.request_id] = "invalid_uuid"
59+
with pytest.raises(ValueError):
60+
await plugin.process_request(mocked_request)

0 commit comments

Comments
 (0)