| import unittest |
|
|
| from opencompass.openicl.icl_prompt_template import PromptTemplate |
| from opencompass.utils.prompt import PromptList |
|
|
|
|
class TestPromptTemplate(unittest.TestCase):
    """Exercise ``PromptTemplate`` with plain-string and dict templates.

    Each test renders ``self.entry`` through a template and compares the
    result with a hand-written expectation.  Independent scenarios inside
    one test run under ``self.subTest`` so that a failure in one scenario
    does not stop the remaining scenarios from being checked and reported.
    """

    # Plain-string template shared by several tests, together with the text
    # expected once ``{input}`` is filled from the entry built in ``setUp``.
    _TRANSLATE_TEMPLATE = (
        'Translate the following English text to French: {input}.')
    _TRANSLATED = ('Translate the following English text to French: '
                   'Hello, how are you?.')

    def setUp(self) -> None:
        """Create the dict templates and the data entry used by the tests."""
        # Single-round template whose ``begin`` section holds a system
        # instruction followed by the '</E>' in-context-example placeholder.
        self.qa_template = dict(
            begin=[
                dict(role='SYSTEM', fallback_role='HUMAN', prompt='instruct'),
                '</E>',
            ],
            round=[
                dict(role='HUMAN', prompt='{input}'),
                dict(role='BOT', prompt='Answer: {answer}'),
            ],
        )
        # Multi-round template: only the first and last turns contain
        # placeholders; it has no ``begin`` section and no '</E>' slot.
        self.multiround_qa_template = dict(round=[
            dict(role='HUMAN', prompt='{input}'),
            dict(role='BOT', prompt='A1', end='\n'),
            dict(role='HUMAN', prompt='Q1'),
            dict(role='BOT', prompt='A2', end='\n\n'),
            dict(role='HUMAN', prompt='Q2', begin='HUMAN:'),
            dict(role='BOT', prompt='Answer: {answer}'),
        ])
        self.entry = {'input': 'Hello, how are you?', 'answer': 'Good.'}

    # ------------------------------------------------------------------
    # Helpers that build expected values.  Each call returns fresh
    # objects so scenarios never share mutable state.
    # ------------------------------------------------------------------

    @staticmethod
    def _section(name, items):
        """Wrap ``items`` in the begin/end marker dicts of section ``name``.

        Args:
            name: Section name stored under the ``'section'`` key.
            items: Iterable of prompt items placed between the markers.

        Returns:
            A plain ``list``: begin marker, the items, end marker.
        """
        return [
            {'section': name, 'pos': 'begin'},
            *items,
            {'section': name, 'pos': 'end'},
        ]

    @staticmethod
    def _filled_qa_round():
        """Return the ``qa_template`` round with the entry values filled."""
        return [
            dict(role='HUMAN', prompt='Hello, how are you?'),
            dict(role='BOT', prompt='Answer: Good.'),
        ]

    @staticmethod
    def _filled_multiround():
        """Return the multi-round template with the entry values filled."""
        return [
            dict(role='HUMAN', prompt='Hello, how are you?'),
            dict(role='BOT', prompt='A1', end='\n'),
            dict(role='HUMAN', prompt='Q1'),
            dict(role='BOT', prompt='A2', end='\n\n'),
            dict(role='HUMAN', prompt='Q2', begin='HUMAN:'),
            dict(role='BOT', prompt='Answer: Good.'),
        ]

    def _sample_ice(self):
        """Return a one-example ICE ``PromptList`` (turns 'h1' / 'b1')."""
        return PromptList(self._section('ice', [
            dict(role='HUMAN', prompt='h1'),
            dict(role='BOT', prompt='b1'),
        ]))

    def _expected_qa_prompt(self):
        """Return the prompt expected from ``qa_template`` + sample ICE.

        The ICE section sits inside the ``begin`` section, right after the
        system instruction, i.e. where the template has its '</E>' token.
        """
        begin_items = [
            dict(role='SYSTEM', fallback_role='HUMAN', prompt='instruct'),
            *self._section('ice', [
                dict(role='HUMAN', prompt='h1'),
                dict(role='BOT', prompt='b1'),
            ]),
        ]
        return PromptList(
            self._section('begin', begin_items)
            + self._section('round', self._filled_qa_round()))

    def _expected_multiround_prompt(self):
        """Return the prompt expected from the multi-round template.

        That template has no '</E>' token, so the expectation holds only
        the filled ``round`` section and no ICE section.
        """
        return PromptList(
            self._section('round', self._filled_multiround()))

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_init(self):
        """The constructor keeps the given template on ``.template``."""
        pt = PromptTemplate(self._TRANSLATE_TEMPLATE)
        self.assertEqual(pt.template, self._TRANSLATE_TEMPLATE)

    def test_generate_ice_item(self):
        """``generate_ice_item`` fills string and dict templates."""
        # Scenario 1: plain-string template yields a plain string.
        with self.subTest(template='string'):
            pt = PromptTemplate(self._TRANSLATE_TEMPLATE)
            ice = pt.generate_ice_item(self.entry, None)
            self.assertEqual(ice, self._TRANSLATED)

        # Scenario 2: dict template; only the round is expected, wrapped
        # in 'ice' section markers.
        with self.subTest(template='qa'):
            pt = PromptTemplate(self.qa_template, ice_token='</E>')
            ice = pt.generate_ice_item(self.entry, None)
            expected = PromptList(
                self._section('ice', self._filled_qa_round()))
            self.assertEqual(ice, expected)

        # Scenario 3: multi-round dict template.
        with self.subTest(template='multiround'):
            pt = PromptTemplate(self.multiround_qa_template,
                                ice_token='</E>')
            ice = pt.generate_ice_item(self.entry, None)
            expected = PromptList(
                self._section('ice', self._filled_multiround()))
            self.assertEqual(ice, expected)

    def test_generate_label_prompt_item(self):
        """``generate_label_prompt_item`` inserts ICE and fills the entry."""
        # Scenario 1: string template; '</E>' is replaced by the ICE text.
        with self.subTest(template='string'):
            pt = PromptTemplate('</E> ' + self._TRANSLATE_TEMPLATE,
                                ice_token='</E>')
            prompt = pt.generate_label_prompt_item(self.entry, 'ICE', None)
            self.assertEqual(prompt, 'ICE ' + self._TRANSLATED)

        # The dict-template scenarios share one ICE list, as the calls
        # being tested take the same in-context examples.
        ice = self._sample_ice()

        # Scenario 2: dict template with an '</E>' slot in ``begin``.
        with self.subTest(template='qa'):
            pt = PromptTemplate(self.qa_template, ice_token='</E>')
            prompt = pt.generate_label_prompt_item(self.entry, ice, None)
            self.assertEqual(prompt, self._expected_qa_prompt())

        # Scenario 3: multi-round dict template without an '</E>' slot.
        with self.subTest(template='multiround'):
            pt = PromptTemplate(self.multiround_qa_template,
                                ice_token='</E>')
            prompt = pt.generate_label_prompt_item(self.entry, ice, None)
            self.assertEqual(prompt, self._expected_multiround_prompt())

    def test_generate_item(self):
        """``generate_item`` fills the entry and optional ICE content."""
        # Scenario 1: plain-string template, no ICE replacement passed.
        with self.subTest(template='string'):
            pt = PromptTemplate(self._TRANSLATE_TEMPLATE)
            item = pt.generate_item(self.entry)
            self.assertEqual(item, self._TRANSLATED)

        ice = self._sample_ice()

        # Scenario 2: dict template with an '</E>' slot in ``begin``.
        with self.subTest(template='qa'):
            pt = PromptTemplate(self.qa_template, ice_token='</E>')
            prompt = pt.generate_item(self.entry,
                                      ice_field_replace_token=ice)
            self.assertEqual(prompt, self._expected_qa_prompt())

        # Scenario 3: multi-round dict template without an '</E>' slot.
        with self.subTest(template='multiround'):
            pt = PromptTemplate(self.multiround_qa_template,
                                ice_token='</E>')
            prompt = pt.generate_item(self.entry,
                                      ice_field_replace_token=ice)
            self.assertEqual(prompt, self._expected_multiround_prompt())
|
|