Spaces:
Runtime error
Runtime error
import pytest

from llama2_wrapper.model import get_prompt_for_dialog
from llama2_wrapper.types import Message
class TestClassGetPromptForDialog:
    """Tests for ``get_prompt_for_dialog`` prompt formatting.

    The messages and dialogs below are class attributes, so one copy is
    shared by every test method; tests must treat them as read-only.
    """

    message1 = Message(
        role="system",
        content="You are a helpful, respectful and honest assistant. ",
    )
    message2 = Message(
        role="user",
        content="Hi do you know Pytorch?",
    )
    message3 = Message(
        role="assistant",
        content="Yes I know Pytorch. ",
    )
    message4 = Message(
        role="user",
        content="Can you write a CNN in Pytorch?",
    )
    message5 = Message(
        role="assistant",
        content="Yes I can write a CNN in Pytorch.",
    )

    # System prompt followed by a single user turn.
    dialog = [message1, message2]
    # System prompt, user turn, assistant reply, then a follow-up user turn.
    dialog2 = [message1, message2, message3, message4]
    # Malformed on purpose: begins with an assistant turn (no system/user
    # opener) and also ends with an assistant turn.
    dialog3 = [message3, message4, message3, message4, message5]

    # Expected rendering of message1 as the header of the first [INST] block;
    # shared by the expected outputs of the well-formed dialogs.
    _SYS_PREFIX = (
        "[INST] <<SYS>>\n"
        "You are a helpful, respectful and honest assistant. \n"
        "<</SYS>>\n\n"
    )

    def test_dialog1(self):
        """System prompt and first user turn render into one [INST] block."""
        prompt = get_prompt_for_dialog(self.dialog)
        assert prompt == self._SYS_PREFIX + "Hi do you know Pytorch? [/INST]"

    def test_dialog2(self):
        """Earlier assistant replies follow their [/INST] before the next turn."""
        prompt = get_prompt_for_dialog(self.dialog2)
        expected = (
            self._SYS_PREFIX
            + "Hi do you know Pytorch? [/INST] Yes I know Pytorch. "
            + "[INST] Can you write a CNN in Pytorch? [/INST]"
        )
        assert prompt == expected

    def test_dialog3(self):
        """A dialog with an invalid role order raises AssertionError."""
        # NOTE(review): presumably raised by a role-order check inside
        # get_prompt_for_dialog; the check itself is not visible here.
        with pytest.raises(AssertionError):
            get_prompt_for_dialog(self.dialog3)