intellij-plugin/features/ai-hints-python/testSrc/com/jetbrains/edu/aiHints/python/PyFunctionDiffReducerTest.kt (1,363 lines of code) (raw):
package com.jetbrains.edu.aiHints.python
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiFile
import com.intellij.psi.PsiFileFactory
import com.jetbrains.edu.aiHints.core.EduAIHintsProcessor
import com.jetbrains.edu.learning.EduTestCase
import com.jetbrains.python.PythonLanguage
import org.junit.Test
class PyFunctionDiffReducerTest : EduTestCase() {
// Diff reducer of the hints processor registered for the current course.
// Evaluated on first access; in this class that happens inside `assertCodeHint`,
// after `courseWithFiles` has been called. Null when no processor is found for the course.
private val pyFunctionDiffReducer by lazy {
  val hintsProcessor = EduAIHintsProcessor.forCourse(getCourse())
  hintsProcessor?.getFunctionDiffReducer()
}
// Student body is only a `# TODO` comment and the hint has two statements:
// the expected result matches the hint, both statements included.
@Test
fun `test simple function`() = assertCodeHint(
functionName = "foo",
currentCode = """
def foo():
# TODO
""",
codeHint = """
def foo():
csv = pd.read_csv("file.csv")
return csv['col_name']
""",
expectedResult = """
def foo():
csv = pd.read_csv("file.csv")
return csv['col_name']
"""
)
// A `pass` body is replaced by a `for` loop with a single statement:
// the expected result matches the hint.
@Test
fun `test add for loop to empty function`() = assertCodeHint(
functionName = "foo",
currentCode = """
def foo(lst):
pass
""",
codeHint = """
def foo(lst):
for l in lst:
print(l)
""",
expectedResult = """
def foo(lst):
for l in lst:
print(l)
"""
)
// Short variant: the hint's `while` is followed by a single statement,
// and the expected result matches the hint including that statement.
@Test
fun `test add while loop to empty function (short)`() = assertCodeHint(
functionName = "foo",
currentCode = """
def foo(lst):
pass
""",
codeHint = """
def foo(lst):
while length(lst) > 0:
lst = lst[:-1]
""",
expectedResult = """
def foo(lst):
while length(lst) > 0:
lst = lst[:-1]
"""
)
// Long variant: the hint has four statements after the `while` header,
// and the expected result keeps only the header with a `pass` placeholder.
@Test
fun `test add while loop to empty function (long)`() = assertCodeHint(
functionName = "foo",
currentCode = """
def foo(lst):
pass
""",
codeHint = """
def foo(lst):
while length(lst) > 0:
lst = lst[:-1]
a += b
c = 42
lst.append(c)
""",
expectedResult = """
def foo(lst):
while length(lst) > 0:
pass
"""
)
// A `pass` body is replaced by an `if` with a single statement:
// the expected result matches the hint.
@Test
fun `test add if statement to empty function`() = assertCodeHint(
functionName = "foo",
currentCode = """
def foo(lst):
pass
""",
codeHint = """
def foo(lst):
if length(lst) > 0:
print("List is not empty")
""",
expectedResult = """
def foo(lst):
if length(lst) > 0:
print("List is not empty")
"""
)
// The hint appends `return c` to the student's code; the expected result matches the hint.
@Test
fun `test add return statement sum`() = assertCodeHint(
functionName = "sum",
currentCode = """
def sum(a, b):
c = a + b
""",
codeHint = """
def sum(a, b):
c = a + b
return c
""",
expectedResult = """
def sum(a, b):
c = a + b
return c
"""
)
// Same scenario as the untyped `sum` test, but with type-annotated parameters:
// the hint appends `return c` and the expected result matches the hint.
@Test
fun `test add return statement sum_typed`() = assertCodeHint(
functionName = "sum_typed",
currentCode = """
def sum_typed(a: int, b: int):
c = a + b
""",
codeHint = """
def sum_typed(a: int, b: int):
c = a + b
return c
""",
expectedResult = """
def sum_typed(a: int, b: int):
c = a + b
return c
"""
)
// Student body is only a `# TODO` comment; the hint's single `return` statement
// is expected in the result, i.e. the result matches the hint.
@Test
fun `test add return statement create_series`() = assertCodeHint(
functionName = "create_series",
currentCode = """
def create_series(dict_: dict):
# TODO
""",
codeHint = """
def create_series(dict_: dict):
return pd.Series(dict_)
""",
expectedResult = """
def create_series(dict_: dict):
return pd.Series(dict_)
"""
)
// Student body is only a comment and the hint has three statements:
// only the first one (`df_final = df.copy()`) is expected in the result.
@Test
fun `test add one statement from a solution`() = assertCodeHint(
functionName = "correct_inconsistency",
currentCode = """
def correct_inconsistency(df):
# Correct inconsistencies in the 'Height' column by converting all values to float
""",
codeHint = """
def correct_inconsistency(df):
df_final = df.copy()
df_final['Height'] = df_final['Height'].astype(float)
return df_final
""",
expectedResult = """
def correct_inconsistency(df):
df_final = df.copy()
"""
)
// The hint inserts a `while` loop between the student's assignments and the `return`.
// The expected result has the two assignments followed by the `while` header with `pass`;
// the `return a + b` present in both inputs is not part of the expected text.
@Test
fun `test add a missing while statement`() = assertCodeHint(
functionName = "while_part",
currentCode = """
def while_part():
a = 52
b = 12
return a + b
""",
codeHint = """
def while_part():
a = 52
b = 12
while b < a:
print(a)
b = b + 10
return a + b
""",
expectedResult = """
def while_part():
a = 52
b = 12
while b < a:
pass
"""
)
// The hint's parameter list differs from the student's only in whitespace, and the hint
// adds three statements. The expected result keeps the student's spelling of the
// parameter list and adds only the first new statement, `print(a)`.
@Test
fun `test empty spaces in the parameter list`() = assertCodeHint(
functionName = "parameter_list_spaces",
currentCode = """
def parameter_list_spaces(a , b):
a += b
""",
codeHint = """
def parameter_list_spaces( a , b ):
a += b
print(a)
print(b)
return a + b + 42
""",
expectedResult = """
def parameter_list_spaces(a , b):
a += b
print(a)
"""
)
// The hint adds a `c=0` parameter and uses it in the `return`:
// the expected result matches the hint (signature and body).
@Test
fun `test complete parameter list`() = assertCodeHint(
functionName = "param_change",
currentCode = """
def param_change(a, b):
return a + b
""",
codeHint = """
def param_change(a, b, c=0):
return a + b + c
""",
expectedResult = """
def param_change(a, b, c=0):
return a + b + c
"""
)
// The hint adds `**kwargs` to a signature that already has a default value and `*args`,
// and extends the `return` expression: the expected result matches the hint.
@Test
fun `test complex parameters`() = assertCodeHint(
functionName = "complex_params",
currentCode = """
def complex_params(a, b=10, *args):
return sum([a, b] + list(args))
""",
codeHint = """
def complex_params(a, b=10, *args, **kwargs):
return sum([a, b] + list(args)) + sum(kwargs.values())
""",
expectedResult = """
def complex_params(a, b=10, *args, **kwargs):
return sum([a, b] + list(args)) + sum(kwargs.values())
"""
)
// Only a parameter annotation differs (`a: int` -> `a: float`):
// the expected result matches the hint.
@Test
fun `test change parameter types`() = assertCodeHint(
functionName = "param_type_change",
currentCode = """
def param_type_change(a: int, b: str):
return str(a) + b
""",
codeHint = """
def param_type_change(a: float, b: str):
return str(a) + b
""",
expectedResult = """
def param_type_change(a: float, b: str):
return str(a) + b
"""
)
// Only the return annotation differs (`-> int` -> `-> float`):
// the expected result matches the hint.
@Test
fun `test complete the return type`() = assertCodeHint(
functionName = "return_type_change",
currentCode = """
def return_type_change(a: int, b: int) -> int:
return a + b
""",
codeHint = """
def return_type_change(a: int, b: int) -> float:
return a + b
""",
expectedResult = """
def return_type_change(a: int, b: int) -> float:
return a + b
"""
)
// The student's header has a space before the colon (`-> int :`), the hint's does not.
// The expected result keeps the student's header as written and adds only the first
// of the hint's four statements.
@Test
fun `test empty spaces in the return type`() = assertCodeHint(
functionName = "foo",
currentCode = """
def foo() -> int :
# TODO
""",
codeHint = """
def foo() -> int:
csv = pd.read_csv("file.csv")
csv['col_name'] = csv['col_name'].apply(lambda x: x.strip())
print(csv['col_name'])
return csv['col_name']
""",
expectedResult = """
def foo() -> int :
csv = pd.read_csv("file.csv")
"""
)
// The hint drops the `-> int` return annotation: the expected result matches the hint.
@Test
fun `test removing the return type annotation`() = assertCodeHint(
functionName = "remove_return_type",
currentCode = """
def remove_return_type(a: int, b: int) -> int:
return a + b
""",
codeHint = """
def remove_return_type(a: int, b: int):
return a + b
""",
expectedResult = """
def remove_return_type(a: int, b: int):
return a + b
"""
)
// The hint inserts a `for` loop before the `return`. The expected result has
// `result = 0` followed by the `for` header with `pass`; `return result` is not
// part of the expected text.
@Test
fun `test add for loop with pass`() = assertCodeHint(
functionName = "for_loop",
currentCode = """
def for_loop(items):
result = 0
return result
""",
codeHint = """
def for_loop(items):
result = 0
for item in items:
result += item
return result
""",
expectedResult = """
def for_loop(items):
result = 0
for item in items:
pass
"""
)
// The hint inserts an `if` before the `return`. The expected result has
// `result = 0` followed by the `if` header with `pass`; `return result` is not
// part of the expected text.
@Test
fun `test add if with pass`() = assertCodeHint(
functionName = "if_statement",
currentCode = """
def if_statement(value):
result = 0
return result
""",
codeHint = """
def if_statement(value):
result = 0
if value > 0:
result = value
return result
""",
expectedResult = """
def if_statement(value):
result = 0
if value > 0:
pass
return result
"""
)
// The student has `result += value` where the hint has an `if` statement.
// The expected result replaces that statement with the `if` header plus `pass`
// and still ends with `return result`.
@Test
fun `test add missing if statement`() = assertCodeHint(
functionName = "replace_statement_type",
currentCode = """
def replace_statement_type(value):
result = 0
result += value
return result
""",
codeHint = """
def replace_statement_type(value):
result = 0
if value > 0:
result = value
return result
""",
expectedResult = """
def replace_statement_type(value):
result = 0
if value > 0:
pass
return result
"""
)
// Step 1 of the `nested_structures` series (the hint is the same in every step).
// The hint adds `for`, `if` and `while` statements; the expected result adds only
// the `for` header with `pass`.
@Test
fun `test nested structures`() = assertCodeHint(
functionName = "nested_structures",
currentCode = """
def nested_structures(items):
result = 0
return result
""",
codeHint = """
def nested_structures(items):
result = 0
for item in items:
if item > 0:
while item > 0:
result += 1
item -= 1
return result
""",
expectedResult = """
def nested_structures(items):
result = 0
for item in items:
pass
"""
)
// Step 2: the student already has the `for` header with `pass`;
// the expected result adds the `if` header with `pass`.
@Test
fun `test nested structures 2`() = assertCodeHint(
functionName = "nested_structures",
currentCode = """
def nested_structures(items):
result = 0
for item in items:
pass
""",
codeHint = """
def nested_structures(items):
result = 0
for item in items:
if item > 0:
while item > 0:
result += 1
item -= 1
return result
""",
expectedResult = """
def nested_structures(items):
result = 0
for item in items:
if item > 0:
pass
"""
)
// Step 3: the student already has the `for` and `if` headers with `pass`;
// the expected result adds the `while` header with `pass`.
@Test
fun `test nested structures 3`() = assertCodeHint(
functionName = "nested_structures",
currentCode = """
def nested_structures(items):
result = 0
for item in items:
if item > 0:
pass
""",
codeHint = """
def nested_structures(items):
result = 0
for item in items:
if item > 0:
while item > 0:
result += 1
item -= 1
return result
""",
expectedResult = """
def nested_structures(items):
result = 0
for item in items:
if item > 0:
while item > 0:
pass
"""
)
// Step 4: the student has the `for` and `if` headers with `pass`;
// the expected result adds the `while` header with `pass`.
// NOTE(review): these fixtures look textually identical to `test nested structures 3`.
// Presumably this step was meant to start from the `while ... pass` state - confirm
// and either adjust the fixtures or drop the duplicate.
@Test
fun `test nested structures 4`() = assertCodeHint(
functionName = "nested_structures",
currentCode = """
def nested_structures(items):
result = 0
for item in items:
if item > 0:
pass
""",
codeHint = """
def nested_structures(items):
result = 0
for item in items:
if item > 0:
while item > 0:
result += 1
item -= 1
return result
""",
expectedResult = """
def nested_structures(items):
result = 0
for item in items:
if item > 0:
while item > 0:
pass
"""
)
// Final step: the student's code lacks only the trailing `return result`;
// the expected result matches the hint.
@Test
fun `test nested structures 5`() = assertCodeHint(
functionName = "nested_structures",
currentCode = """
def nested_structures(items):
result = 0
for item in items:
if item > 0:
while item > 0:
result += 1
item -= 1
""",
codeHint = """
def nested_structures(items):
result = 0
for item in items:
if item > 0:
while item > 0:
result += 1
item -= 1
return result
""",
expectedResult = """
def nested_structures(items):
result = 0
for item in items:
if item > 0:
while item > 0:
result += 1
item -= 1
return result
"""
)
// Step 1 of the `check_number` series (steps 1-6 share the same hint).
// The student has only `pass`; the expected result adds the first `for` header with `pass`.
@Test
fun `test check_number full path`() = assertCodeHint(
functionName = "check_number",
currentCode = """
def check_number(num):
pass
""",
codeHint = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
print("Nogative")
elif num == 0:
print("WA")
""",
expectedResult = """
def check_number(num):
for i in range(0, 10):
pass
"""
)
// Step 2: the student has the first `for` header with `pass`;
// the expected result adds the second `for` header with `pass`.
@Test
fun `test check_number full path 2`() = assertCodeHint(
functionName = "check_number",
currentCode = """
def check_number(num):
for i in range(0, 10):
pass
""",
codeHint = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
print("Nogative")
elif num == 0:
print("WA")
""",
expectedResult = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
pass
"""
)
// Step 3: the student has both `for` headers with `pass`;
// the expected result adds the `if num > 0:` header with `pass`.
@Test
fun `test check_number full path 3`() = assertCodeHint(
functionName = "check_number",
currentCode = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
pass
""",
codeHint = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
print("Nogative")
elif num == 0:
print("WA")
""",
expectedResult = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
pass
"""
)
// Step 4: the student's `if num > 0:` has `pass`; the expected result replaces it
// with `print("Positive")`. No `elif` part appears in the expected text yet.
@Test
fun `test check_number full path 4`() = assertCodeHint(
functionName = "check_number",
currentCode = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
pass
""",
codeHint = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
print("Nogative")
elif num == 0:
print("WA")
""",
expectedResult = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
"""
)
// Step 5: the student's `elif num < 0:` has `pass`; the expected result replaces it
// with the hint's statement. The second `elif` is not part of the expected text yet.
@Test
fun `test check_number full path 5`() = assertCodeHint(
functionName = "check_number",
currentCode = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
pass
""",
codeHint = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
print("Nogative")
elif num == 0:
print("WA")
""",
expectedResult = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
print("Nogative")
"""
)
// Step 6: the student's `elif num == 0:` has `pass`; the expected result replaces it
// with the hint's statement, so the result matches the hint.
@Test
fun `test check_number full path 6`() = assertCodeHint(
functionName = "check_number",
currentCode = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
print("Nogative")
elif num == 0:
pass
""",
codeHint = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
print("Nogative")
elif num == 0:
print("WA")
""",
expectedResult = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
print("Nogative")
elif num == 0:
print("WA")
"""
)
// Step 7: a new hint changes the statements under both `elif` parts
// ("Nogative" -> "Negative", "WA" -> "Zero"). The expected result applies only the
// first change; `print("WA")` stays as the student wrote it.
@Test
fun `test check_number full path 7`() = assertCodeHint(
functionName = "check_number",
currentCode = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
print("Nogative")
elif num == 0:
print("WA")
""",
codeHint = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
print("Negative")
elif num == 0:
print("Zero")
""",
expectedResult = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
print("Negative")
elif num == 0:
print("WA")
"""
)
// Step 8: the only remaining difference is `print("WA")` vs `print("Zero")`;
// the expected result matches the hint.
@Test
fun `test check_number full path 8`() = assertCodeHint(
functionName = "check_number",
currentCode = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
print("Negative")
elif num == 0:
print("WA")
""",
codeHint = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
print("Negative")
elif num == 0:
print("Zero")
""",
expectedResult = """
def check_number(num):
for i in range(0, 10):
for j in range(0, 10):
if num > 0:
print("Positive")
elif num < 0:
print("Negative")
elif num == 0:
print("Zero")
"""
)
// The student's body is a single `pass` and the hint has two statements:
// the expected result matches the hint, both statements included.
@Test
fun `test function with pass body`() = assertCodeHint(
functionName = "empty_function",
currentCode = """
def empty_function():
pass
""",
codeHint = """
def empty_function():
result = 42
return result
""",
expectedResult = """
def empty_function():
result = 42
return result
"""
)
// The student's body is only comments; the hint keeps the comments and adds two statements.
// The expected result contains only the first statement: neither the comments nor
// the `return` are part of the expected text.
@Test
fun `test function with comments`() = assertCodeHint(
functionName = "comments_only",
currentCode = """
def comments_only():
# This function does nothing
# It just has comments
""",
codeHint = """
def comments_only():
# This function does nothing
# It just has comments
result = "Hello, World!"
return result
""",
expectedResult = """
def comments_only():
result = "Hello, World!"
"""
)
// The hint adds three statements after the student's assignment, one of which is
// `print("\n")`. The expected result adds only the first one, `print(result)`.
@Test
fun `test function with spaces comments`() = assertCodeHint(
functionName = "spaces",
currentCode = """
def spaces():
result = "Hello, World!"
""",
codeHint = """
def spaces():
result = "Hello, World!"
print(result)
print("\n")
return result
""",
expectedResult = """
def spaces():
result = "Hello, World!"
print(result)
"""
)
// The hint changes the `if` condition and both returned strings. The expected result
// changes only the condition (`>` -> `>=`); the student's return values are kept.
@Test
fun `test new condition in the if`() = assertCodeHint(
functionName = "modify_condition_if",
currentCode = """
def modify_condition_if(value):
if value > 0:
return "positive"
return "non-positive"
""",
codeHint = """
def modify_condition_if(value):
if value >= 0:
return "non-negative"
return "negative"
""",
expectedResult = """
def modify_condition_if(value):
if value >= 0:
return "positive"
return "non-positive"
"""
)
// The hint changes the `while` condition and turns the last statement into a `return`.
// The expected result changes only the condition (`>` -> `>=`); the student's
// `networkCall()` statement is kept as is.
@Test
fun `test new condition in the while`() = assertCodeHint(
functionName = "modify_condition_while",
currentCode = """
def modify_condition_while(value):
while value > 0:
value-=1
networkCall()
""",
codeHint = """
def modify_condition_while(value):
while value >= 0:
value-=1
return networkCall()
""",
expectedResult = """
def modify_condition_while(value):
while value >= 0:
value-=1
networkCall()
"""
)
// The hint changes the `range` bound in the `for` header and replaces `pass` with
// two statements. The expected result changes only the header; `pass` is kept.
@Test
fun `test new condition in the for`() = assertCodeHint(
functionName = "modify_condition_for",
currentCode = """
def modify_condition_for(value):
for i in range(0, 52):
pass
""",
codeHint = """
def modify_condition_for(value):
for i in range(0, 104):
data = read_data(i)
return data[i]
""",
expectedResult = """
def modify_condition_for(value):
for i in range(0, 104):
pass
"""
)
// A single assignment differs (`[]` vs a list comprehension):
// the expected result matches the hint.
@Test
fun `test update one line`() = assertCodeHint(
functionName = "list_comprehension",
currentCode = """
def list_comprehension(items):
result = []
return result
""",
codeHint = """
def list_comprehension(items):
result = [x * 2 for x in items if x > 0]
return result
""",
expectedResult = """
def list_comprehension(items):
result = [x * 2 for x in items if x > 0]
return result
"""
)
// The hint adds `result.append(a)` after `a = i * i`; the expected result matches
// the hint, including the trailing `return result`.
@Test
fun `test for part with spaces`() = assertCodeHint(
functionName = "for_part",
currentCode = """
def for_part(items):
result = []
for i in range(0, 9):
a = i * i
return result
""",
codeHint = """
def for_part(items):
result = []
for i in range(0, 9):
a = i * i
result.append(a)
return result
""",
expectedResult = """
def for_part(items):
result = []
for i in range(0, 9):
a = i * i
result.append(a)
return result
"""
)
// The student's `pass` after the `while` header is replaced by `result.append(i)`;
// the expected result matches the hint, including the trailing `return result`.
@Test
fun `test while part with spaces`() = assertCodeHint(
functionName = "while_part_with_spaces",
currentCode = """
def while_part_with_spaces(items):
result = []
while i > 10:
pass
return result
""",
codeHint = """
def while_part_with_spaces(items):
result = []
while i > 10:
result.append(i)
return result
""",
expectedResult = """
def while_part_with_spaces(items):
result = []
while i > 10:
result.append(i)
return result
"""
)
// The student writes `print (` with a space where the hint has `print(`, and the hint
// adds `current -= 1`. The expected result keeps the student's `print (` spelling
// and adds only the new `current -= 1` statement.
@Test
fun `test while part with spaces 2`() = assertCodeHint(
functionName = "count_down",
currentCode = """
def count_down(start_number):
current = start_number
total_sum = 0
while current > 0:
print (f"Counting: {current}")
total_sum += current
print ("Countdown complete!")
return total_sum
""",
codeHint = """
def count_down(start_number):
current = start_number
total_sum = 0
while current > 0:
print(f"Counting: {current}")
total_sum += current
current -= 1
print("Countdown complete!")
return total_sum
""",
expectedResult = """
def count_down(start_number):
current = start_number
total_sum = 0
while current > 0:
print (f"Counting: {current}")
total_sum += current
current -= 1
print ("Countdown complete!")
return total_sum
"""
)
// Both versions have a `for` with an `else:` part; only the statement after `else:`
// differs. The expected result matches the hint.
@Test
fun `test for loop with else`() = assertCodeHint(
functionName = "for_loop_with_else",
currentCode = """
def for_loop_with_else(items):
for i in range(0, 10):
print(i)
else:
print("Irrelevant line")
""",
codeHint = """
def for_loop_with_else(items):
for i in range(0, 10):
print(i)
else:
print("Done!")
""",
expectedResult = """
def for_loop_with_else(items):
for i in range(0, 10):
print(i)
else:
print("Done!")
""",
)
// The hint adds an `else:` part to the student's `for` loop;
// the expected result adds `else:` with a `pass` placeholder.
@Test
fun `test adding else part to function with for loop`() = assertCodeHint(
functionName = "for_loop_with_else",
currentCode = """
def for_loop_with_else(items):
for i in range(0, 10):
print(i)
""",
codeHint = """
def for_loop_with_else(items):
for i in range(0, 10):
print(i)
else:
print("Done!")
""",
expectedResult = """
def for_loop_with_else(items):
for i in range(0, 10):
print(i)
else:
pass
""",
)
// Both versions have a `while` with an `else:` part; only the statement after `else:`
// differs. The expected result matches the hint.
@Test
fun `test while loop with else`() = assertCodeHint(
functionName = "while_loop_with_else",
currentCode = """
def while_loop_with_else(count):
i = 0
while i < count:
print(i)
i += 1
else:
print("Irrelevant line")
""",
codeHint = """
def while_loop_with_else(count):
i = 0
while i < count:
print(i)
i += 1
else:
print("Loop completed normally")
""",
expectedResult = """
def while_loop_with_else(count):
i = 0
while i < count:
print(i)
i += 1
else:
print("Loop completed normally")
""",
)
// The hint adds an `else:` part to the student's `while` loop;
// the expected result adds `else:` with a `pass` placeholder.
@Test
fun `test adding else part to function with while loop`() = assertCodeHint(
functionName = "while_loop_with_else",
currentCode = """
def while_loop_with_else(count):
i = 0
while i < count:
print(i)
i += 1
""",
codeHint = """
def while_loop_with_else(count):
i = 0
while i < count:
print(i)
i += 1
else:
print("Loop completed normally")
""",
expectedResult = """
def while_loop_with_else(count):
i = 0
while i < count:
print(i)
i += 1
else:
pass
""",
)
// Both versions have an `if` with an `else:` part; only the statement after `else:`
// differs. The expected result matches the hint.
@Test
fun `test if-else structure modification`() = assertCodeHint(
functionName = "check_number",
currentCode = """
def check_number(num):
if num > 0:
print("Positive")
else:
print("Irrelevant line")
""",
codeHint = """
def check_number(num):
if num > 0:
print("Positive")
else:
print("Zero")
""",
expectedResult = """
def check_number(num):
if num > 0:
print("Positive")
else:
print("Zero")
""",
)
// The hint adds an `else:` part to the student's `if`;
// the expected result adds `else:` with a `pass` placeholder.
@Test
fun `test adding else part to function with if statement`() = assertCodeHint(
functionName = "check_number",
currentCode = """
def check_number(num):
if num > 0:
print("Positive")
""",
codeHint = """
def check_number(num):
if num > 0:
print("Positive")
else:
print("Zero")
""",
expectedResult = """
def check_number(num):
if num > 0:
print("Positive")
else:
pass
""",
)
// The hint adds an `elif num < 0:` part to the student's `if`;
// the expected result adds the `elif` header with a `pass` placeholder.
// NOTE(review): despite the name, this fixture adds a new `elif` rather than modifying
// an existing one; the name looks swapped with
// `test adding elif part to function with if statement` - confirm.
@Test
fun `test if-elif structure modification`() = assertCodeHint(
functionName = "check_number",
currentCode = """
def check_number(num):
if num > 0:
print("Positive")
""",
codeHint = """
def check_number(num):
if num > 0:
print("Positive")
elif num < 0:
print("Negative")
""",
expectedResult = """
def check_number(num):
if num > 0:
print("Positive")
elif num < 0:
pass
""",
)
// Both versions already have `elif num < 0:`; only the statement after it differs.
// The expected result matches the hint.
// NOTE(review): despite the name, no `elif` is added here - an existing one is modified;
// the name looks swapped with `test if-elif structure modification` - confirm.
@Test
fun `test adding elif part to function with if statement`() = assertCodeHint(
functionName = "check_number",
currentCode = """
def check_number(num):
if num > 0:
print("Positive")
elif num < 0:
print("something else")
""",
codeHint = """
def check_number(num):
if num > 0:
print("Positive")
elif num < 0:
print("Negative")
""",
expectedResult = """
def check_number(num):
if num > 0:
print("Positive")
elif num < 0:
print("Negative")
""",
)
// The student already has one `elif`; the hint adds a second one (`elif num == 0:`).
// The expected result adds the new `elif` header with a `pass` placeholder.
@Test
fun `test adding elif part to function with another elif statement`() = assertCodeHint(
functionName = "check_number",
currentCode = """
def check_number(num):
if num > 0:
print("Positive")
elif num < 0:
print("Negative")
""",
codeHint = """
def check_number(num):
if num > 0:
print("Positive")
elif num < 0:
print("Negative")
elif num == 0:
print("Zero")
""",
expectedResult = """
def check_number(num):
if num > 0:
print("Positive")
elif num < 0:
print("Negative")
elif num == 0:
pass
""",
)
// The student has two `elif` parts; the hint has none and adds `return num`.
// The expected result drops both `elif` parts; `return num` is not part of
// the expected text.
@Test
fun `test removing elif parts when there are no in the CodeHint`() = assertCodeHint(
functionName = "check_number",
currentCode = """
def check_number(num):
if num > 0:
print("Positive")
elif num < 0:
print("Negative")
elif num == 0:
print("Zero")
""",
codeHint = """
def check_number(num):
if num > 0:
print("Positive")
return num
""",
expectedResult = """
def check_number(num):
if num > 0:
print("Positive")
""",
)
// The statements under both `elif` parts differ between student and hint.
// The expected result applies only the first change ("Nogative" -> "Negative");
// `print("WA")` stays as the student wrote it.
@Test
fun `test if-elif content modification in elif parts`() = assertCodeHint(
functionName = "check_number",
currentCode = """
def check_number(num):
if num > 0:
print("Positive")
elif num < 0:
print("Nogative")
elif num == 0:
print("WA")
""",
codeHint = """
def check_number(num):
if num > 0:
print("Positive")
elif num < 0:
print("Negative")
elif num == 0:
print("Zero")
""",
expectedResult = """
def check_number(num):
if num > 0:
print("Positive")
elif num < 0:
print("Negative")
elif num == 0:
print("WA")
""",
)
/**
 * Tests for adding a new function to the student code.
 * The expected result here is the function that will be suggested to add.
 *
 * In this case the student's file contains only `foo`, while the code hint also defines
 * `new_function` (a comment plus a single `return`). The expected result is the whole
 * `new_function` as written in the hint, comment included.
 */
@Test
fun `test adding small function`() = assertCodeHint(
functionName = "new_function",
currentCode = """
def foo():
return 42
""",
codeHint = """
def foo():
return 42
def new_function():
# This function returns string
return "Hello, World!"
""",
expectedResult = """
def new_function():
# This function returns string
return "Hello, World!"
""",
)
// The student's file contains only `foo`; the hint also defines `new_function` with a
// three-line body. The expected result is the `new_function` header with a `pass` body.
// NOTE(review): the hint body (`for (l in lst):`, `l++`) is not valid Python; presumably
// only the size of the body matters for this test - confirm, or replace with valid code.
@Test
fun `test adding big function`() = assertCodeHint(
functionName = "new_function",
currentCode = """
def foo():
return 42
""",
codeHint = """
def foo():
return 42
def new_function():
for (l in lst):
print(l)
l++
""",
expectedResult = """
def new_function():
pass
""",
)
/**
 * Runs the course's function diff reducer for the function [functionName] and checks
 * the text of the produced PSI element.
 *
 * A Python course with a single task file containing [currentCode] is created, and
 * [codeHint] is parsed into a standalone `codeHint.py` PSI file. Both texts go through
 * [String.trimIndent] first, as does [expectedResult] before the comparison.
 *
 * @param functionName name used to look the function up in both files; it must exist in
 *   the code hint, but may be absent from the current code (a `null` is then passed to
 *   the reducer as the current function)
 * @param currentCode student's code placed into the task file
 * @param codeHint full code of the hint
 * @param expectedResult expected text of the element returned by the reducer
 */
private fun assertCodeHint(
  functionName: String,
  currentCode: String,
  codeHint: String,
  expectedResult: String
) {
  // given
  courseWithFiles(language = PythonLanguage.INSTANCE) {
    lesson(PY_LESSON) {
      eduTask(PY_TASK) {
        pythonTaskFile(PY_TASK_FILE, currentCode.trimIndent())
      }
    }
  }
  val currentPsiFile = getPsiFile(project, PY_LESSON, PY_TASK, PY_TASK_FILE)
  // A distinct name avoids shadowing the `codeHint` parameter.
  val codeHintPsiFile = PsiFileFactory.getInstance(project)
    .createFileFromText("codeHint.py", PythonLanguage.INSTANCE, codeHint.trimIndent())
  val functionFromCode = getFunctionPsiWithName(currentPsiFile, functionName)
  val functionFromCodeHint = getFunctionPsiWithName(codeHintPsiFile, functionName)
    ?: error("Function `$functionName` is not found in the code hint")
  // Fail with an explicit message instead of letting a missing reducer surface
  // as an "expected: <...> but was: null" assertion failure below.
  val reducer = pyFunctionDiffReducer
    ?: error("Function diff reducer is not available for the test course")

  // when
  val resultPsiElement = reducer.reduceDiffFunctions(functionFromCode, functionFromCodeHint)

  // then
  assertEquals(expectedResult.trimIndent(), resultPsiElement?.text)
}
/**
 * Looks up [functionName] in [codePsiFile] through the function signature manager of the
 * hints processor registered for the current course.
 *
 * Returns `null` when no processor is available for the course or the lookup itself
 * returns `null`.
 */
private fun getFunctionPsiWithName(codePsiFile: PsiFile, functionName: String): PsiElement? {
  val signatureManager = EduAIHintsProcessor.forCourse(getCourse())?.getFunctionSignatureManager()
  return signatureManager?.getFunctionBySignature(codePsiFile, functionName)
}
}