generated from salesforce/oss-template
-
Notifications
You must be signed in to change notification settings - Fork 1
/
resample_baseline.py
142 lines (128 loc) · 5.42 KB
/
resample_baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# LICENSE HEADER MANAGED BY add-license-header
#
# /*
# * Copyright (c) 2023, Salesforce, Inc.
# * SPDX-License-Identifier: Apache-2
# *
# * Licensed under the Apache License, Version 2.0 (the "License");
# * you may not use this file except in compliance with the License.
# * You may obtain a copy of the License at
# *
# * http://www.apache.org/licenses/LICENSE-2.0
# *
# * Unless required by applicable law or agreed to in writing, software
# * distributed under the License is distributed on an "AS IS" BASIS,
# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# * See the License for the specific language governing permissions and
# * limitations under the License.
# */
#
from utils import enumerate_resume, make_printv, write_jsonl, resume_success_count
from executors import executor_factory
import sys
from common import gen_test_eval
from generators import generator_factory, model_factory
from typing import List, Dict, Tuple, Any
import math
import sys
from collections import Counter
sys.set_int_max_str_digits(100000)  # Raise the int<->str conversion limit to 100,000 digits
class Node:
    """One candidate solution in the sampling tree, with MCTS-style statistics.

    Attributes:
        solution: the candidate implementation (source text).
        parent: parent ``Node`` or ``None`` for the root.
        children: child nodes (appended externally by the caller).
        value: accumulated reward from ``update``.
        visits: number of ``update`` calls.
        context: free-form context string passed at construction.
        depth: depth of this node in the tree.
        reflection: reflection text (set externally).
        test_feedback: test-execution feedback (set externally).
    """

    def __init__(self, solution: str, parent=None, context="", depth=0):
        self.solution = solution
        self.parent = parent
        self.children = []
        self.value = 0
        self.visits = 0
        # BUG FIX: previously hard-coded to "" which silently discarded the
        # `context` argument; now stores what the caller passed.
        self.context = context
        self.depth = depth
        self.reflection = ""
        self.test_feedback = ""

    def uct(self, exploration_weight=1.0):
        """Upper-confidence score; an unvisited node just reports its raw value."""
        if self.visits == 0:
            # return float('inf')
            return self.value
        return (self.value / self.visits) + exploration_weight * math.sqrt(math.log(self.parent.visits) / self.visits)

    def best_child(self):
        """Return the child with the highest UCT score, or None if childless."""
        if not self.children:  # Check if children list is empty
            return None
        return max(self.children, key=lambda child: child.uct())

    def best_child_value(self):
        """Return the child node with the highest accumulated value, or None."""
        if not self.children:  # Check if children list is empty
            return None
        return max(self.children, key=lambda child: child.value)

    def sort_children_by_value(self):
        """Sort children in place by accumulated value, ascending."""
        self.children.sort(key=lambda x: x.value)

    def update(self, reward: float):
        """Record one visit and add `reward` to the accumulated value."""
        self.visits += 1
        self.value += reward
def resample(
    dataset: List[dict],
    model_name: str,
    language: str,
    max_iters: int,
    log_path: str,
    verbose: bool,
    is_leetcode: bool = False,
    Codecontests: bool = False
) -> None:
    """Resampling baseline: draw up to ``max_iters`` i.i.d. candidate solutions per problem.

    For each item, implementations are sampled at temperature 1.0 until one
    passes the item's public tests or the sample budget runs out.  A candidate
    that passes the public tests is scored against the hidden tests
    (``item["test"]`` and, when present, ``item["weaker_test"]``), running
    accuracies are attached to the item, and it is appended to ``log_path``.
    Items that never pass their public tests are printed but not logged.

    Args:
        dataset: problems; each dict needs "prompt", "given_tests", "test",
            "prev", and (unless ``Codecontests``) "entry_point".
        model_name: key understood by ``model_factory``.
        language: key understood by ``executor_factory`` / ``generator_factory``.
        max_iters: sample budget per problem.
        log_path: JSONL file that solved/passing items are appended to.
        verbose: enables progress printing via ``make_printv``.
        is_leetcode: forwarded to ``executor_factory`` for non-CodeContests runs.
        Codecontests: use the CodeContests executor and ignore entry points.
    """
    if Codecontests:
        exe = executor_factory("code_contests")
    else:
        exe = executor_factory(language, is_leet=is_leetcode)
    pass_problem_subset = []  # difficulty tags of solved problems (when annotated)
    gen = generator_factory(language)
    model = model_factory(model_name)
    print_v = make_printv(verbose)
    num_items = len(dataset)
    num_success, weak_success = 0, 0  # full-test / weaker-test success counters
    passed_at_sample = []  # 1-based sample count at which each solve happened
    for idx, item in enumerate(dataset):
        print("STARTING EXAMPLE", idx)
        tests_i = item["given_tests"]
        if Codecontests:
            item["entry_point"] = ""
        else:
            # Keep only tests that actually exercise the entry point and are
            # not degenerate "assert False" placeholders.
            tests_i = [test for test in tests_i if item['entry_point'] in test and 'assert False' not in test]
        root = Node("")
        stack = [root]  # root plus one node per sampled implementation
        is_solved, is_weaker_solved = False, False
        num_try = 0
        # BUG FIX: initialize before the loop; previously a NameError was
        # raised below if max_iters == 0 (loop body never ran).
        is_passing = False
        for i in range(max_iters):
            # Retry generation until the model returns a non-None implementation.
            cur_func_impl = None
            while cur_func_impl is None:
                cur_func_impl = gen.func_impl(item["prompt"], model, "simple", temperature=1.0)
            node = Node(cur_func_impl)
            stack.append(node)
            root.children.append(node)
            is_passing, feedback, reward = gen_test_eval(exe, cur_func_impl, tests_i, prev=item["prev"])
            num_try += 1
            node.update(reward)
            node.test_feedback = feedback
            if is_passing:
                # Public tests passed: score against the hidden tests, then
                # stop sampling for this item (early exit).
                is_solved = exe.evaluate(
                    item["entry_point"], cur_func_impl, item["test"], timeout=1, prev=item["prev"])
                if "weaker_test" in item.keys():
                    is_weaker_solved = exe.evaluate(
                        item["entry_point"], cur_func_impl, item["weaker_test"], timeout=1, prev=item["prev"])
                break
        # Exit when passed public test cases.
        if is_passing:
            if is_solved:
                num_success += int(is_solved)
                passed_at_sample.append(num_try)
                if "difficulty" in item.keys():
                    pass_problem_subset.append(item["difficulty"])
            else:
                print("sad, passed but no solve.")
            if is_weaker_solved:
                weak_success += int(is_weaker_solved)
            item["acc"] = round(num_success / (idx + 1), 3)
            item["weak_acc"] = round(weak_success / (idx + 1), 3)
            write_jsonl(log_path, [item], append=True)
            print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={round(weak_success / (idx + 1), 3)}')
            continue  # early stop on this case if passed
        # Fell through max_iters without passing the public tests: dump
        # running statistics.  NOTE(review): failed items are intentionally
        # not written to log_path.
        print("_______________________________")
        print(passed_at_sample)
        print(sorted(passed_at_sample))
        print(len(passed_at_sample))
        print(Counter(passed_at_sample))
        print(Counter(pass_problem_subset))
        print_v(f'completed {idx + 1}/{num_items}: acc = {round(num_success / (idx + 1), 3)}, weak_acc={round(weak_success / (idx + 1), 3)}')