11import streamlit as st
22import asyncio
33import logging
4- from typing import List , Optional
5- from pydantic import BaseModel , Field
6- from models import TestCase
4+ from models import PipelineRequest
75from pipeline import CodeGenerationPipeline
86from dotenv import load_dotenv
97import os
1412
1513load_dotenv ()
1614
17- api_key = os .getenv ("GROQ_API_KEY" )
18- if not api_key :
19- raise ValueError ("GROQ API KEY is not set" )
15+ # api_key = os.getenv("GROQ_API_KEY")
16+ # if not api_key:
17+ # raise ValueError("GROQ API KEY is not set")
2018
2119# Configure logging
2220logging .basicConfig (level = logging .INFO )
5755 "gemma2-9b-it"
5856]
5957
60-
61- class PipelineRequest (BaseModel ):
62- model : str
63- language : str
64- question : str
65- explanation : Optional [str ] = None
66- user_input : Optional [str ] = None
67- max_iterations : int = 3
68- generate_test_cases : bool = True
69- test_cases : Optional [List [TestCase ]] = []
70-
7158@app .post ("/run_pipeline" )
7259async def run_pipeline (data : PipelineRequest ):
7360 try :
@@ -90,7 +77,7 @@ async def run_pipeline(data: PipelineRequest):
9077 )
9178
9279 pipeline = CodeGenerationPipeline (
93- api_key = os . getenv ( "GROQ_API_KEY" ) ,
80+ api_key = data . api_key ,
9481 base_url = "https://api.groq.com/openai/v1" ,
9582 max_iterations = data .max_iterations
9683 )
@@ -135,179 +122,4 @@ async def run_pipeline(data: PipelineRequest):
135122 raise HTTPException (status_code = 500 , detail = str (e ))
136123
137124if __name__ == "__main__" :
138- uvicorn .run (app , host = "0.0.0.0" , port = 8000 , reload = True )
139-
140-
141- # st.title("CodeCraft")
142-
143- # selected_model = st.selectbox(
144- # 'Select a model:',
145- # model_ids
146- # )
147-
148- # # Main input fields
149- # question = st.text_area(
150- # "Question",
151- # placeholder="Write a function that returns the sum of two numbers",
152- # help="Enter the coding question or task."
153- # )
154-
155- # explanation = st.text_area(
156- # "Explanation",
157- # placeholder="Create a simple function that takes two numbers as input and returns their sum.",
158- # help="Provide additional context or explanation."
159- # )
160-
161- # user_input = st.text_input(
162- # "User Input (for testing)",
163- # placeholder="2 3",
164- # help="Input to test the generated code."
165- # )
166-
167-
168- # # Language dropdown
169- # language = st.selectbox("Language", list(LANGUAGE_MAPPING.keys()), help="Select the programming language for the code.")
170- # language_code = LANGUAGE_MAPPING[language]
171-
172- # # Slider for max iterations
173- # max_iterations = st.slider("Max Iterations", min_value=1, max_value=5, value=3, help="Maximum number of iterations for the pipeline.")
174-
175- # # Test cases input
176- # st.header("Test Cases")
177-
178- # # Option to generate test cases automatically
179- # generate_test_cases = st.checkbox("Generate Test Cases Automatically", value=True, help="Automatically generate test cases based on the problem statement.")
180-
181- # # Initialize session state to store test cases
182- # if "test_cases" not in st.session_state:
183- # st.session_state.test_cases = [{"input": "", "expected_output": ""}] # Start with one test case
184-
185- # # Function to add a new test case
186- # def add_test_case():
187- # st.session_state.test_cases.append({"input": "", "expected_output": ""})
188-
189- # # Collapsible section for test cases
190- # with st.expander("Manage Test Cases", expanded=True):
191- # if not generate_test_cases:
192- # # Display test cases
193- # test_cases = []
194- # for i, test_case in enumerate(st.session_state.test_cases):
195- # st.subheader(f"Test Case {i + 1}")
196- # input_data = st.text_input(f"Input {i + 1}", value=test_case["input"], key=f"input_{i}", help="Input for the test case.")
197- # expected_output = st.text_input(f"Expected Output {i + 1}", value=test_case["expected_output"], key=f"expected_output_{i}", help="Expected output for the test case.")
198- # test_cases.append(TestCase(input=input_data, expected_output=expected_output))
199-
200- # # Button to add more test cases
201- # st.button("Add Test Case", on_click=add_test_case)
202- # else:
203- # test_cases = []
204-
205- # # Run pipeline button
206- # if st.button("Run Pipeline"):
207- # pipeline = CodeGenerationPipeline(
208- # api_key=api_key,
209- # base_url="https://api.groq.com/openai/v1",
210- # max_iterations=max_iterations
211- # )
212-
213- # # Run the pipeline asynchronously
214- # async def run_pipeline_async():
215- # return await pipeline.run_pipeline(
216- # model = selected_model,
217- # language=language_code,
218- # question=question,
219- # test_cases=test_cases,
220- # explanation=explanation,
221- # user_input=user_input
222- # )
223-
224- # # Display a spinner while the pipeline is running
225- # with st.spinner("Running pipeline..."):
226- # # Save the question to the database
227- # question = save_question(
228- # model=selected_model,
229- # question_text=question,
230- # explanation=explanation,
231- # user_input=user_input,
232- # language=language_code,
233- # max_iterations=max_iterations
234- # )
235- # result = asyncio.run(run_pipeline_async())
236-
237- # # Display results
238- # st.header("Results")
239- # st.subheader("Chain of Thought")
240- # for i, step in enumerate(result.cot, 1):
241- # st.write(f"{i}. {step}")
242- # st.subheader("Final Code")
243- # st.code(result.final_code, language=language_code)
244-
245- # st.subheader("Execution Result")
246- # st.text(result.final_result.output)
247-
248- # st.subheader("Test Case Results")
249- # for i, test_result in enumerate(result.test_results):
250- # st.write(f"### Test Case {i + 1}")
251- # st.write(f"**Input:** {test_result.input}") # Use dot notation
252- # st.write(f"**Expected Output:** {test_result.expected_output}") # Use dot notation
253- # st.write(f"**Actual Output:** {test_result.actual_output}") # Use dot notation
254- # st.write(f"**Time:** {test_result.time}") # Use dot notation
255- # st.write(f"**Memory:** {test_result.memory}") # Use dot notation
256- # if test_result.stderror: # Use dot notation
257- # st.write(f"**Standard Error:** {test_result.stderror}") # Use dot notation
258- # if test_result.compiler_errors:
259- # st.write(f"**Compiler Errors:** {test_case.compiler_errors}")
260- # st.write("---")
261-
262- # st.subheader("Pipeline Metadata")
263- # st.write(f"**Total Iterations:** {result.iterations}")
264- # st.write(f"**Success:** {result.success}")
265-
266- # # Display iteration history
267- # st.subheader("Iteration History")
268- # for history in result.history:
269- # # Save the iteration to the database
270- # iteration = save_iteration(
271- # question_id=question.id,
272- # iteration_number=history.iteration,
273- # chain_of_thought=history.chain_of_thought,
274- # generated_code=history.code,
275- # success=all(test_case.passed for test_case in history.test_results)
276- # )
277- # st.write(f"### Iteration {history.iteration}")
278- # st.write("**Chain of Thought:**")
279- # for i, step in enumerate(history.chain_of_thought, 1):
280- # st.write(f"{i}. {step}")
281- # st.write("**Generated Code:**")
282- # st.code(history.code, language=language_code)
283- # st.write("**Execution Result:**")
284- # st.text(history.execution_result.output)
285- # if history.execution_result.stderror:
286- # st.write(f"**Standard Error:** {history.execution_result.stderror}")
287- # if history.execution_result.compiler_errors:
288- # st.write(f"**Compiler Errors:** {history.execution_result.compiler_errors}")
289- # st.write("**Test Case Results:**")
290- # for i, test_result in enumerate(history.test_results):
291- # save_test_case_results(
292- # iteration_id=iteration.id,
293- # input_data=test_result.input,
294- # expected_output=test_result.expected_output,
295- # actual_output=test_result.actual_output,
296- # execution_time=test_result.time,
297- # memory_usage=test_result.memory,
298- # stderror=test_result.stderror or "",
299- # compiler_errors=test_result.compiler_errors or "",
300- # passed=test_result.passed
301- # )
302- # st.write(f"#### Test Case {i + 1}")
303- # st.write(f"**Input:** {test_result.input}")
304- # st.write(f"**Expected Output:** {test_result.expected_output}")
305- # st.write(f"**Actual Output:** {test_result.actual_output}")
306- # st.write(f"**Time:** {test_result.time}")
307- # st.write(f"**Memory:** {test_result.memory}")
308- # st.write(f"**Passed:** {test_result.passed}")
309- # if test_result.stderror:
310- # st.write(f"**Standard Error:** {test_result.stderror}")
311- # if test_result.compiler_errors:
312- # st.write(f"**Compiler Errors:** {test_result.compiler_errors}")
313- # st.write("---")
125+ uvicorn .run (app , host = "0.0.0.0" , port = 8000 , reload = True )
0 commit comments