Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions evaluation/async_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(

if not os.path.exists(self.model_config_path):
raise FileNotFoundError(f"Model configuration for {model_name} not found in {model_configs_dir}")

with open(self.model_config_path, "r") as file:
self.model_config = json.load(file)

Expand Down Expand Up @@ -84,11 +84,11 @@ def __init__(

if not os.path.exists(haystack_path):
raise FileNotFoundError(f"Haystack file not found at {haystack_path}")

self.haystack_path = haystack_path

self.haystack = BookHaystack(self.haystack_path)

self.results_dir = results_dir
os.makedirs(results_dir, exist_ok=True)

Expand All @@ -103,14 +103,14 @@ def __init__(
self.seed = seed
self.prevent_duplicate = prevent_duplicate
self.distractor = distractor

self.log_placements = log_placements_dir != ""
self.log_placements_dir = log_placements_dir

self.test_name = test_name
self.eval_name = f"{model_name}_book_{test_name}_{int(time.time())}" if test_name != "" else f"{model_name}_book_{int(time.time())}"


def _evaluate_response(self, response: str, gold_answers = None) -> int:
if gold_answers is None:
gold_answers = self.gold_answers
Expand All @@ -132,7 +132,7 @@ def get_hash(self, test_config: dict) -> str:
del new_config["results"]
del new_config["eval_name"]
return hashlib.sha256(json.dumps(new_config, sort_keys=True).encode()).hexdigest()


def evaluate(self) -> None:
np.random.seed(self.seed)
Expand Down Expand Up @@ -170,7 +170,7 @@ def evaluate(self) -> None:
print(f"Results already exist at {results_path}")
print("Skipping evaluation")
return

if self.prevent_duplicate:
for result_filename in os.listdir(self.results_dir):
with open(os.path.join(self.results_dir, result_filename), "r") as file:
Expand All @@ -180,7 +180,7 @@ def evaluate(self) -> None:
print(f"Duplicate test found with similar hash at {self.results_dir} -- TEST_HASH:", outputs["test_hash"])
print("Skipping evaluation")
return

async_tasks = []
for i in tqdm(np.linspace(self.document_depth_percent_min, self.document_depth_percent_max, self.document_depth_percent_intervals)):
needle_depth = i / 100
Expand All @@ -200,7 +200,7 @@ def evaluate(self) -> None:
retrieval_question = self.retrieval_question

placement_output = self.haystack.generate_w_needle_placement(
needle=needle,
needle=needle,
token_count_func=self.api_connector.token_count,
encoding_func=self.api_connector.encode,
decoding_func=self.api_connector.decode,
Expand All @@ -216,7 +216,7 @@ def evaluate(self) -> None:
async_tasks.append(self.api_connector.generate_response(
system_prompt=self.system_prompt,
user_prompt=filled_template,
max_tokens=self.model_config["max_tokens"],
max_tokens=self.model_config["max_tokens"],
temperature=self.model_config["temperature"],
top_p=self.model_config["top_p"]
))
Expand All @@ -238,14 +238,14 @@ def evaluate(self) -> None:
outputs["results"][i]["metric"] = self._evaluate_response(responses[i]["response"], gold_answers=[outputs["results"][i]["selected_character"]]) if "{CHAR}" in self.needle else self._evaluate_response(responses[i]["response"])
for k, v in responses[i].items():
outputs["results"][i][k] = v

# Save results by model name, haystack type, timestamp
with open(results_path, "w") as file:
json.dump(outputs, file, indent=4)

print(f"Results saved at {results_path}")




if __name__ == "__main__":
Expand All @@ -261,7 +261,6 @@ def evaluate(self) -> None:
parser.add_argument("--system_prompt", type=str, help="System prompt for the model")
parser.add_argument("--use_default_system_prompt", type=bool, default=False, help="Use default system prompt")
parser.add_argument("--task_template", type=str, help="Task template for the model")
parser.add_argument("--system_prompt", type=str, help="System prompt for the model")
parser.add_argument("--context_length", type=int, help="Context length for the needle placement")
parser.add_argument("--document_depth_percent_min", type=int, default=0, help="Minimum document depth percentage")
parser.add_argument("--document_depth_percent_max", type=int, default=100, help="Maximum document depth percentage")
Expand Down