Skip to content

Commit bee191d

Browse files
authored
Fix pyright type errors and AWS Bedrock semantic tag bug (#47)
## Summary

- Resolved all 74 pyright type errors across the codebase
- Fixed bug where AWS Bedrock LLM assigns semantic tags to numeric columns (LONG, BIGINT, etc.), causing Stitch jobs to fail with "Semantics on fields of type LONG are not supported"

## Changes

**Type fixes:**

- Aligned `ModelInfo` TypedDict across LLM providers
- Fixed `WizardState.models` type hint
- Added proper None handling and type narrowing throughout
- Fixed except clause ordering and Callable imports

**Stitch/PII fix:**

- Added `NUMERIC_TYPES` filter in `stitch_tools.py` to strip semantics from numeric columns
- Updated PII detection prompt to instruct LLM not to tag numeric columns

## Test plan

- [x] All 626 unit tests pass
- [x] Pyright reports 0 errors
- [x] Ruff linting passes
1 parent 402dddc commit bee191d

32 files changed

+165
-103
lines changed

chuck_data/agent/manager.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -210,29 +210,36 @@ def process_with_tools(self, tools, max_iterations: int = 20):
210210
if response_message.tool_calls:
211211
# Add the assistant's response (requesting tool calls) to history
212212
# Convert ChatCompletionMessage to dict format for consistency
213+
tool_calls_list = []
214+
for tc in response_message.tool_calls:
215+
func = getattr(tc, "function", None)
216+
if func is not None:
217+
tool_calls_list.append(
218+
{
219+
"id": tc.id,
220+
"type": getattr(tc, "type", "function"),
221+
"function": {
222+
"name": getattr(func, "name", ""),
223+
"arguments": getattr(func, "arguments", "{}"),
224+
},
225+
}
226+
)
213227
assistant_msg = {
214228
"role": "assistant",
215229
"content": response_message.content,
216-
"tool_calls": [
217-
{
218-
"id": tc.id,
219-
"type": getattr(tc, "type", "function"),
220-
"function": {
221-
"name": tc.function.name,
222-
"arguments": tc.function.arguments,
223-
},
224-
}
225-
for tc in response_message.tool_calls
226-
],
230+
"tool_calls": tool_calls_list,
227231
}
228232
self.conversation_history.append(assistant_msg)
229233

230234
# Execute each tool call
231235
for tool_call in response_message.tool_calls:
232-
tool_name = tool_call.function.name
236+
func = getattr(tool_call, "function", None)
237+
if func is None:
238+
continue
239+
tool_name = getattr(func, "name", "")
233240
tool_id = tool_call.id
234241
try:
235-
tool_args = json.loads(tool_call.function.arguments)
242+
tool_args = json.loads(getattr(func, "arguments", "{}"))
236243
tool_result = execute_tool(
237244
self.api_client,
238245
tool_name,
@@ -276,7 +283,7 @@ def process_with_tools(self, tools, max_iterations: int = 20):
276283
continue
277284
else:
278285
# No tool calls, this is the final response
279-
final_content = response_message.content
286+
final_content = response_message.content or ""
280287
# remove all lines with any <function> tags
281288
final_content = "\n".join(
282289
line

chuck_data/agent/tool_executor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
from chuck_data.clients.databricks import (
2424
DatabricksAPIClient,
2525
) # For type hinting api_client
26-
from typing import Dict, Any, Optional, List
26+
from typing import Dict, Any, Optional, List, Callable
27+
from jsonschema.exceptions import ValidationError
2728

2829

2930
# The display_to_user utility and individual tool implementation functions
@@ -48,7 +49,7 @@ def execute_tool(
4849
api_client: Optional[DatabricksAPIClient],
4950
tool_name: str,
5051
tool_args: Dict[str, Any],
51-
output_callback: Optional[callable] = None,
52+
output_callback: Optional[Callable[..., Any]] = None,
5253
) -> Dict[str, Any]:
5354
"""Execute a tool (command) by its name with the provided arguments.
5455
@@ -87,7 +88,7 @@ def execute_tool(
8788
try:
8889
jsonschema.validate(instance=tool_args, schema=schema_to_validate)
8990
logging.debug(f"Tool arguments for '{tool_name}' validated successfully.")
90-
except jsonschema.exceptions.ValidationError as ve:
91+
except ValidationError as ve:
9192
logging.error(
9293
f"Validation error for tool '{tool_name}' args {tool_args}: {ve.message}"
9394
)

chuck_data/api_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ def upload_file(self, path, file_path=None, content=None, overwrite=False):
135135
binary_data = f.read()
136136
else:
137137
# Convert string content to bytes
138+
# content is guaranteed non-None by the validation above
139+
assert content is not None
138140
binary_data = content.encode("utf-8")
139141

140142
try:

chuck_data/clients/amperity.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import webbrowser
99
import readchar
1010
import json
11+
from typing import Optional
1112
from rich.console import Console
1213

1314
from chuck_data.config import set_amperity_token
@@ -113,7 +114,7 @@ def get_auth_status(self) -> dict:
113114
return {"state": self.state, "nonce": self.nonce, "has_token": bool(self.token)}
114115

115116
def wait_for_auth_completion(
116-
self, poll_interval: int = 1, timeout: int = None
117+
self, poll_interval: int = 1, timeout: Optional[int] = None
117118
) -> tuple[bool, str]:
118119
"""Wait for authentication to complete in a blocking manner."""
119120
if not self.nonce:

chuck_data/clients/databricks.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -724,6 +724,8 @@ def upload_file(self, path, file_path=None, content=None, overwrite=False):
724724
binary_data = f.read()
725725
else:
726726
# Convert string content to bytes
727+
# content is guaranteed non-None by the validation above
728+
assert content is not None
727729
binary_data = content.encode("utf-8")
728730

729731
try:

chuck_data/command_output.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from chuck_data.ui.table_formatter import display_table
1212

13-
from chuck_data.command_result import CommandResult
13+
from chuck_data.commands.base import CommandResult
1414
from chuck_data.ui.theme import (
1515
SUCCESS,
1616
WARNING,
@@ -132,7 +132,7 @@ def format_for_agent(result: CommandResult) -> Dict[str, Any]:
132132
}
133133

134134
# Start with a base response
135-
response = {"success": True}
135+
response: Dict[str, Any] = {"success": True}
136136

137137
# Add the message if available
138138
if result.message:

chuck_data/commands/catalog_selection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def handle_command(client: Optional[DatabricksAPIClient], **kwargs) -> CommandRe
6464
client: API client instance
6565
**kwargs: catalog (str) - catalog name, tool_output_callback (optional)
6666
"""
67-
catalog: str = kwargs.get("catalog")
67+
catalog = kwargs.get("catalog")
6868
tool_output_callback = kwargs.get("tool_output_callback")
6969

7070
if not catalog:

chuck_data/commands/job_status.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -639,7 +639,7 @@ def handle_list_jobs(client=None, **kwargs) -> CommandResult:
639639
cached_job_data = job_entry.get("job_data")
640640

641641
# If we have cached data for a terminal state, use it
642-
if cached_job_data:
642+
if cached_job_data and isinstance(cached_job_data, dict):
643643
state = (cached_job_data.get("state") or "").lower().replace(":", "")
644644
# Only use cache for terminal states (succeeded, failed, unknown)
645645
if state in ["succeeded", "success", "failed", "error", "unknown"]:

chuck_data/commands/jobs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ def handle_launch_job(client: Optional[DatabricksAPIClient], **kwargs) -> Comman
1313
client: API client instance
1414
**kwargs: config_path (str), init_script_path (str), run_name (str, optional), tool_output_callback (callable, optional)
1515
"""
16-
config_path: str = kwargs.get("config_path")
17-
init_script_path: str = kwargs.get("init_script_path")
16+
config_path = kwargs.get("config_path")
17+
init_script_path = kwargs.get("init_script_path")
1818
run_name: Optional[str] = kwargs.get("run_name")
1919
tool_output_callback = kwargs.get("tool_output_callback")
2020
policy_id: Optional[str] = kwargs.get("policy_id")

chuck_data/commands/model_selection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def handle_command(client: Optional[DatabricksAPIClient], **kwargs) -> CommandRe
2424
client: API client instance (used for Databricks provider)
2525
**kwargs: model_name (str)
2626
"""
27-
model_name: str = kwargs.get("model_name")
27+
model_name = kwargs.get("model_name")
2828
if not model_name:
2929
return CommandResult(False, message="model_name parameter is required.")
3030

@@ -45,7 +45,7 @@ def handle_command(client: Optional[DatabricksAPIClient], **kwargs) -> CommandRe
4545
models_list = provider.list_models(tool_calling_only=False)
4646

4747
# Extract model IDs (field name varies by provider)
48-
model_ids = [m.get("model_id") or m.get("name") for m in models_list]
48+
model_ids = [m.get("model_id") or m.get("name") or "" for m in models_list]
4949

5050
# Validate model exists
5151
if model_name not in model_ids:

0 commit comments

Comments
 (0)