David Berenstein authored
# Description

I added a very rough outline of my ideation behind `prepare_for_training` with the new `FeedbackDataset`. As discussed, there are 3 complexities:

- How to resolve annotator alignment?
- How to resolve optional fields that have not been filled out? e.g., "Please provide a correction for prompt 1."
- How to handle potential concatenation of fields?

To keep it modular I created a step-wise approach:

1. `Pydantic` models that map and verify data fields. By doing this we keep the flexibility to allow for other tasks like TextClassification, and it ensures we can directly use `dataset.fields` and `dataset.questions` for defining training. We could also use the `name` values from the fields/questions, but this might be more error-prone.
2. `get_relevant_data_for_training()` returns a `List[dict]` with all relevant fields from the Pydantic model. **Annotator alignment issue**: for now I opted for choosing the first non-zero value.
3. Forward the `List[dict]` to a flow similar to the one we had previously.
4. Added a `dataset.unify_responses(question, Enum(strategy))` method.
5. Added `*QuestionUnification` to schemas to hold the logic for unifying multiple responses.
6. Added `client.feedback.training`.
7. Added `TrainingDataFor*` to hold the logic for the `prepare_for_training` methods per task.
8.
Added inheritance for `ArgillaTrainer`:

```python
import argilla as rg
from argilla import (
    ArgillaTrainer,
    FeedbackRecord,
    LabelQuestion,
    LabelQuestionUnification,
    MultiLabelQuestion,
    TrainingDataForTextClassification,
)

dataset = rg.FeedbackDataset(
    guidelines="Add some guidelines for the annotation team here.",
    fields=[
        rg.TextField(name="text", title="Human prompt"),
    ],
    questions=[
        LabelQuestion(
            name="relevant",
            title="Is the response relevant for the given prompt?",
            labels=["yes", "no"],
            required=True,
            visible_labels=None,
        ),
        MultiLabelQuestion(
            name="content_class",
            title="Does the response include any of the following?",
            description="Select all that apply",
            labels={
                "hate": "Hate Speech",
                "sexual": "Sexual content",
                "violent": "Violent content",
                "pii": "Personal information",
                "untruthful": "Untruthful info",
                "not_english": "Not English",
                "inappropriate": "Inappropriate content",
            },
            required=False,
            visible_labels=4,
        ),
    ],
)
dataset.add_records(
    records=[
        FeedbackRecord(
            fields={"text": "What is your favorite color?"},
            responses=[
                {"values": {"relevant": {"value": "yes"}, "content_class": {"value": ["hate"]}}},
            ],
        ),
        FeedbackRecord(
            fields={"text": "What do you think about the new iPhone?"},
            responses=[
                {"values": {"relevant": {"value": "no"}, "content_class": {"value": ["hate"]}}},
            ],
        ),
        FeedbackRecord(
            fields={"text": "What is your feeling about the technology?"},
            responses=[
                {"values": {"relevant": {"value": "yes"}, "content_class": {"value": ["sexual"]}}},
                {"values": {"relevant": {"value": "no"}, "content_class": {"value": ["hate", "sexual"]}}},
                {"values": {"relevant": {"value": "yes"}, "content_class": {"value": ["hate", "sexual"]}}},
            ],
        ),
        FeedbackRecord(
            fields={"text": "Jesus Christ!"},
            responses=[
                {"values": {"relevant": {"value": "no"}, "content_class": {"value": ["sexual"]}}},
                {"values": {"relevant": {"value": "no"}, "content_class": {"value": ["hate"]}}},
            ],
        ),
    ]
)

# print(dataset.question_by_name("relevant").__all_labels__)
label = LabelQuestionUnification(
    question=dataset.question_by_name("relevant"), strategy="majority"
)
training_data = TrainingDataForTextClassification(
    text=dataset.field_by_name("text"), label=label
)
for framework in ["spacy", "transformers", "openai", "spark-nlp"]:
    formatted_data = dataset.prepare_for_training(
        framework, training_data, fetch_records=False, train_size=0.8
    )
    print(formatted_data)

trainer = ArgillaTrainer(
    dataset=dataset,
    training_task_mapping=training_data,  # was an undefined name in the original snippet
    framework="setfit",
    fetch_records=False,
)
trainer.train("test")
```

Closes #2954
Closes #3184
Closes #3152

**Type of change**

- [X] New feature (non-breaking change which adds functionality)
- [X] Improvement (change adding some improvement to an existing functionality)

**How Has This Been Tested**

- [ ] Test A
- [ ] Test B

**Checklist**

- [ ] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [ ] My code follows the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Alvaro Bartolome <alvaro@argilla.io>
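As a footnote on step 2 above: the "first non-zero value" heuristic for the annotator alignment issue is only described, not shown. Here is a minimal, self-contained sketch of that idea over plain dicts shaped like the `FeedbackRecord` responses above; the function body and the hard-coded `"relevant"`/`"text"` keys are illustrative assumptions, not the actual implementation in this PR.

```python
from typing import List

def get_relevant_data_for_training(records: List[dict]) -> List[dict]:
    """Sketch: for each record, keep the first response that carries a
    non-empty value for the required question ("relevant"), skipping
    responses where the optional field was left blank."""
    rows = []
    for record in records:
        for response in record.get("responses", []):
            values = response.get("values", {})
            if values.get("relevant", {}).get("value"):
                rows.append({
                    "text": record["fields"]["text"],
                    "relevant": values["relevant"]["value"],
                })
                break  # first usable response wins
    return rows

records = [
    {
        "fields": {"text": "What is your favorite color?"},
        "responses": [
            {"values": {"relevant": {"value": ""}}},     # blank: skipped
            {"values": {"relevant": {"value": "yes"}}},  # first usable
        ],
    }
]
print(get_relevant_data_for_training(records))
# [{'text': 'What is your favorite color?', 'relevant': 'yes'}]
```

Returning a flat `List[dict]` at this point is what lets step 3 reuse the existing `prepare_for_training` flow unchanged.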
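Similarly, the `"majority"` strategy passed to `LabelQuestionUnification` in the example is not spelled out. A rough standalone sketch of majority unification over a single `LabelQuestion` (the function name and signature here are hypothetical, not Argilla's API) could look like this:

```python
from collections import Counter
from typing import List

def unify_majority(responses: List[dict], question_name: str) -> str:
    """Sketch: pick the most frequent answer to `question_name` across
    all annotators of one record. Works for single-label questions,
    whose values are hashable strings."""
    votes = [
        r["values"][question_name]["value"]
        for r in responses
        if question_name in r.get("values", {})
    ]
    if not votes:
        raise ValueError(f"No responses for question {question_name!r}")
    # most_common(1) returns [(value, count)] for the top vote
    return Counter(votes).most_common(1)[0][0]

responses = [
    {"values": {"relevant": {"value": "yes"}}},
    {"values": {"relevant": {"value": "no"}}},
    {"values": {"relevant": {"value": "yes"}}},
]
print(unify_majority(responses, "relevant"))  # yes
```

Multi-label questions like `content_class` would need a different treatment (their values are lists, hence the per-question `*QuestionUnification` schemas in step 5).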