Skip to content

Commit 1629227

Browse files
zsimjeeirgolicCalebCourier
authored
Typing and XML removal (#414)
* Runner: dataclass => pydantic BaseModel * guard: Add overload typedefs for __call__ and parse * Correctly type guard.py and references * run: Fully typed * guard: Fix parse type overload * cli: one typehint fix * guard: Raise RuntimeError if num_reasks is None after calling configure * run: Reconsolidate rebased parse call * run: Add type expecting reasks * Remove <pydantic> plumbing * rail: Type * guard: Propagate typehints * base_prompt: Correctly return None on no match * Type schema, validators, datatypes * format * rail: Default input schema to None * Allow metadata to be None instead of default dict * Add coro checks to SaliencyCheck validator * Refactor MockValidator to create_mock_validator This now creates a type dynamically instead of just an instance * datatypes: Import Self from typing_extensions instead of typing (python3.11+) * Add typing-extensions dependency This is already included by some other dependency, but we should also explicitly require it * Remove PydanticReask (old pydantic plumbing) * Fix mocked openai embeddings response * analytics * embedding: Typing * document_store: Typing * vectordb: Typing * docs_utils: Typing * json_utils: Typing * logs_utils: Typing * reask_utils: Typing * pydantic_utils: Typing * format * sql_utils: Typing * parsing_utils: Typing * utils/misc: Typing * pydantic_utils: Fixup * add missing imports back in * fix tests * lint fixes * ignore tests when calculating coverage * match js pattern to mkdocs docs * add GA * directly add postgot and pipedrive keys * GA stream directly in yml * Update canonical uri * Formatter.parse -> Template.get_identifiers * polyfill get_identifiers for older python versions * test_get_template_variables * autoformat * swap spy * simplify test * correct minimum version * remove unused impot * use safe_substitute to guard against monetary format i.e. $5 * ignore pyenv version * test prompt.escape * lint fix * ignore lint error on line * schema.FormatAttr: Remove internal XML dependence * datatypes: Store optional as param in datatype * text2sql: Typing * Makefile: Add type target * guard: Type return as Any * llm_providers: Typing * JsonSchema.parse: Small typing * validators.provenancev0: Rewrite to satisfy typing * ci: Add pyright job * setup: Add pyright to dev dependencies * Ignore handled missing imports * pydantic_utils: Return instead of pass (bug?) * setup: Require dev dependency lxml-stubs (for typing) * validators.ProvenanceV1: Rewrite for typing * Fix xml attr string casts * json_utils: Take datatype dict instead of xml element * datatypes: Abstract out name and description * datatypes: Verify metadata over datatype * Split validators into separate package * schema/reask_utils: Migrate from internal XML dependence * tests: Fix * Rewrite pruned tree test * format * Refactor removeprefix for python 3.8 * Abstract out xml element string cast * add a few UTs * test_async: Test StringSchema.async_validate * test_xml_utils: Add test for test_xml_element_cast * test manifest * schema: Remove unused func * add pydantic chat integration test * test_llm_providers: Add openai chat tests with basemodel * Fix some type errors * datatypes: set_children => set_children_from_xml * datatypes: Remove iter methods * datatypes: Remove xml element property * schema: Simplify from_xml construction * test_reask_utils: Amend test with correct validator example * lint * circ references, semi-passing tests * correct string default typing * date_format fix * run.call: correctly pass output as llm_response * guard.parse: Pop kwargs off api initialization (as it used to do) * mock_llm_outputs: Add tomato cheese pizza example * test_run: formatting * fix regex validator * fix validator test type sussing * lint * unused import removed * typing 50% fixed * text completion * ignore pyright warning on np * lint * fix numpy typing issue * Check for match type on init * merge latest from public main * ignore tpying on optional packages * fix import path and typing on detect_secrets package --------- Co-authored-by: Rafael Irgolic <[email protected]> Co-authored-by: rafael <[email protected]> Co-authored-by: Caleb Courier <[email protected]>
1 parent eded75d commit 1629227

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+2440
-1711
lines changed

.github/workflows/ci.yml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,33 @@ jobs:
4242
run: |
4343
make lint
4444
45+
Typing:
46+
runs-on: ubuntu-latest
47+
strategy:
48+
matrix:
49+
python-version: ['3.8', '3.9', '3.10', '3.11']
50+
51+
steps:
52+
- uses: actions/checkout@v2
53+
- name: Set up Python ${{ matrix.python-version }}
54+
uses: actions/setup-python@v2
55+
with:
56+
python-version: ${{ matrix.python-version }}
57+
58+
- uses: actions/cache@v2
59+
with:
60+
path: ~/.cache/pip
61+
key: ${{ runner.os }}-pip
62+
63+
- name: Install Dependencies
64+
run: |
65+
python -m pip install --upgrade pip
66+
make dev
67+
68+
- name: Static analysis with pyright
69+
run: |
70+
make type
71+
4572
Pytests:
4673
runs-on: ubuntu-latest
4774
strategy:

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@ scratch/
2323
test.db
2424
test.index
2525
htmlcov
26+
.python-version

Makefile

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ autoformat:
55
isort --atomic guardrails/ tests/
66
docformatter --in-place --recursive guardrails tests
77

8+
type:
9+
pyright guardrails/
10+
811
lint:
912
isort -c guardrails/ tests/
1013
black guardrails/ tests/ --check
@@ -36,4 +39,9 @@ dev:
3639
full:
3740
pip install -e ".[all]"
3841

39-
all: autoformat lint docs test
42+
all: autoformat type lint docs test
43+
44+
precommit:
45+
# pytest -x -q --no-summary
46+
pyright guardrails/
47+
make lint

codecov.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
ignore:
2-
- "guardrails/version.py"
2+
- "guardrails/version.py"
3+
- "tests"

docs/javascripts/pipedrive.js

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/javascripts/posthog.js

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

guardrails/applications/text2sql.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import asyncio
12
import json
23
import os
34
from string import Template
4-
from typing import Callable, Dict, Optional
5+
from typing import Callable, Dict, Optional, Type
56

67
import openai
78

@@ -63,14 +64,14 @@ def __init__(
6364
conn_str: str,
6465
schema_file: Optional[str] = None,
6566
examples: Optional[Dict] = None,
66-
embedding: Optional[EmbeddingBase] = OpenAIEmbedding,
67-
vector_db: Optional[VectorDBBase] = Faiss,
68-
document_store: Optional[DocumentStoreBase] = EphemeralDocumentStore,
67+
embedding: Type[EmbeddingBase] = OpenAIEmbedding,
68+
vector_db: Type[VectorDBBase] = Faiss,
69+
document_store: Type[DocumentStoreBase] = EphemeralDocumentStore,
6970
rail_spec: Optional[str] = None,
7071
rail_params: Optional[Dict] = None,
71-
example_formatter: Optional[Callable] = example_formatter,
72-
reask_prompt: Optional[str] = REASK_PROMPT,
73-
llm_api: Optional[Callable] = openai.Completion.create,
72+
example_formatter: Callable = example_formatter,
73+
reask_prompt: str = REASK_PROMPT,
74+
llm_api: Callable = openai.Completion.create,
7475
llm_api_kwargs: Optional[Dict] = None,
7576
num_relevant_examples: int = 2,
7677
):
@@ -119,7 +120,7 @@ def _init_guard(
119120
schema_file: Optional[str] = None,
120121
rail_spec: Optional[str] = None,
121122
rail_params: Optional[Dict] = None,
122-
reask_prompt: Optional[str] = REASK_PROMPT,
123+
reask_prompt: str = REASK_PROMPT,
123124
):
124125
# Initialize the Guard class
125126
if rail_spec is None:
@@ -144,9 +145,9 @@ def _init_guard(
144145
def _create_docstore_with_examples(
145146
self,
146147
examples: Optional[Dict],
147-
embedding: EmbeddingBase,
148-
vector_db: VectorDBBase,
149-
document_store: DocumentStoreBase,
148+
embedding: Type[EmbeddingBase],
149+
vector_db: Type[VectorDBBase],
150+
document_store: Type[DocumentStoreBase],
150151
) -> Optional[DocumentStoreBase]:
151152
if examples is None:
152153
return None
@@ -167,7 +168,7 @@ def _create_docstore_with_examples(
167168
def output_schema_formatter(output) -> str:
168169
return json.dumps({"generated_sql": output}, indent=4)
169170

170-
def __call__(self, text: str) -> str:
171+
def __call__(self, text: str) -> Optional[str]:
171172
"""Run text2sql on a text query and return the SQL query."""
172173

173174
if self.store is not None:
@@ -179,6 +180,12 @@ def __call__(self, text: str) -> str:
179180
else:
180181
similar_examples_prompt = ""
181182

183+
if asyncio.iscoroutinefunction(self.llm_api):
184+
raise ValueError(
185+
"Async API is not supported in Text2SQL application. "
186+
"Please use a synchronous API."
187+
)
188+
182189
try:
183190
output = self.guard(
184191
self.llm_api,
@@ -188,7 +195,11 @@ def __call__(self, text: str) -> str:
188195
"db_info": str(self.sql_schema),
189196
},
190197
**self.llm_api_kwargs,
191-
)[1]["generated_sql"]
198+
)[ # type: ignore
199+
1
200+
][
201+
"generated_sql"
202+
]
192203
except TypeError:
193204
output = None
194205

guardrails/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def compile_rail(rail: str, out: str) -> None:
1212
raise NotImplementedError("Currently compiling rail is not supported.")
1313

1414

15-
def validate_llm_output(rail: str, llm_output: str) -> bool:
15+
def validate_llm_output(rail: str, llm_output: str) -> dict:
1616
"""Validate guardrails.yml file."""
1717
guard = Guard.from_rail(rail)
1818
result = guard.parse(llm_output)

0 commit comments

Comments
 (0)