JackyChunKit committed on
Commit
1dc597f
·
verified ·
1 Parent(s): 9b03bab

Upload rl_dataset.py

Browse files
Files changed (1) hide show
  1. rl_dataset.py +256 -0
rl_dataset.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import os
17
+ import re
18
+ from collections import defaultdict
19
+ from typing import List, Optional, Union
20
+
21
+ import datasets
22
+ import numpy as np
23
+ import torch
24
+ from omegaconf import DictConfig, ListConfig
25
+ from torch.utils.data import Dataset
26
+ from transformers import PreTrainedTokenizer, ProcessorMixin
27
+
28
+ import verl.utils.torch_functional as verl_F
29
+ from verl.utils.model import compute_position_id_with_mask
30
+
31
+
32
def collate_fn(data_list: list[dict]) -> dict:
    """Collate a list of per-sample dicts into one batch dict.

    Tensor-valued entries are stacked along a new leading batch dimension;
    every other value is gathered into a numpy object array so ragged or
    non-numeric payloads survive batching untouched.
    """
    tensor_batch = defaultdict(list)
    object_batch = defaultdict(list)

    # Route each field of each sample into the tensor or non-tensor bucket.
    for sample in data_list:
        for name, value in sample.items():
            bucket = tensor_batch if isinstance(value, torch.Tensor) else object_batch
            bucket[name].append(value)

    collated = {name: torch.stack(values, dim=0) for name, values in tensor_batch.items()}
    collated.update({name: np.array(values, dtype=object) for name, values in object_batch.items()})
    return collated
50
+
51
+
52
class RLHFDataset(Dataset):
    """Map-style dataset that turns parquet chat prompts into RLHF training inputs.

    We assume each row contains a list of chat messages under ``prompt_key`` and,
    optionally, multi-modal payloads under ``image_key`` / ``video_key``.
    ``__getitem__`` applies the chat template, tokenizes, left-pads/truncates to
    ``max_prompt_length`` and returns input_ids / attention_mask / position_ids
    plus bookkeeping fields (raw_prompt_ids, index, ...).
    """

    def __init__(
        self,
        data_files: Union[str, List[str]],
        tokenizer: PreTrainedTokenizer,
        config: DictConfig,
        processor: Optional[ProcessorMixin] = None,
    ):
        """Download the parquet files to a local cache, load them and (optionally)
        filter out prompts longer than ``max_prompt_length`` tokens.

        Args:
            data_files: a single parquet path or a list of them (local or remote).
            tokenizer: tokenizer used for chat templating and length filtering.
            config: dataset options (prompt_key, max_prompt_length, truncation, ...).
            processor: optional multi-modal processor; when given, images/videos
                are encoded through it in ``__getitem__``.
        """
        if not isinstance(data_files, (List, ListConfig)):
            data_files = [data_files]

        self.data_files = copy.deepcopy(data_files)
        self.original_data_files = copy.deepcopy(data_files)  # use for resume
        self.tokenizer = tokenizer
        self.processor = processor
        self.config = config

        self.cache_dir = os.path.expanduser(config.get("cache_dir", "~/.cache/verl/rlhf"))
        self.prompt_key = config.get("prompt_key", "prompt")
        self.image_key = config.get("image_key", "images")
        self.video_key = config.get("video_key", "videos")
        self.max_prompt_length = config.get("max_prompt_length", 1024)

        self.return_raw_chat = config.get("return_raw_chat", False)
        self.truncation = config.get("truncation", "error")
        self.filter_overlong_prompts = config.get("filter_overlong_prompts", True)

        # os.cpu_count() may return None (platform-dependent); fall back to 1 so
        # neither the // 4 default nor the min() clamp can raise TypeError.
        cpu_total = os.cpu_count() or 1
        self.num_workers = config.get("filter_overlong_prompts_workers", max(1, cpu_total // 4))
        self.num_workers = min(self.num_workers, cpu_total)

        # whether to store the dataset in state_dict()
        # default not store
        self.serialize_dataset = False
        self._download()
        self._read_files_and_tokenize()

    def _download(self, use_origin_parquet=False):
        """Copy every parquet file into the local cache dir, rewriting
        ``self.data_files`` to the local paths.

        Args:
            use_origin_parquet: when True, copy from the pristine original paths
                instead of the (possibly already-localized) current ones; used
                when resuming from a checkpoint.
        """
        from verl.utils.fs import copy_to_local

        data_files = self.data_files if not use_origin_parquet else self.original_data_files
        for i, parquet_file in enumerate(data_files):
            self.data_files[i] = copy_to_local(src=parquet_file, cache_dir=self.cache_dir)

    def _read_files_and_tokenize(self):
        """Load and concatenate all parquet files into ``self.dataframe``;
        optionally drop rows whose chat-templated prompt exceeds
        ``max_prompt_length`` tokens."""
        dataframes = []
        for parquet_file in self.data_files:
            # read parquet files and cache
            dataframe = datasets.load_dataset("parquet", data_files=parquet_file)["train"]
            dataframes.append(dataframe)
        self.dataframe: datasets.Dataset = datasets.concatenate_datasets(dataframes)

        print(f"dataset len: {len(self.dataframe)}")

        # filter out too long prompts
        if self.filter_overlong_prompts:
            # bind to locals so the filter lambda captures as little as possible
            tokenizer = self.tokenizer
            prompt_key = self.prompt_key
            self.dataframe = self.dataframe.filter(
                lambda doc: len(tokenizer.apply_chat_template(doc[prompt_key], add_generation_prompt=True))
                <= self.max_prompt_length,
                num_proc=self.num_workers,
                desc=f"Filtering prompts longer than {self.max_prompt_length} tokens",
            )

            print(f"filter dataset len: {len(self.dataframe)}")

    def resume_dataset_state(self):
        """Re-materialize the dataframe after loading from a checkpoint.

        New-style checkpoints only keep the original file paths, so the data is
        re-downloaded and re-filtered; old-style checkpoints serialized the
        dataset itself and just get a warning.
        """
        self.serialize_dataset = not hasattr(self, "original_data_files")
        # resume the dataframe unless it was serialized inside the checkpoint
        if not self.serialize_dataset:
            self._download(use_origin_parquet=True)  # download and resume from original parquet files
            self._read_files_and_tokenize()
        else:
            print(r"old dataloader ckpt file is used, please train from scratch for better ckpt performance")

    def __len__(self):
        return len(self.dataframe)

    def _build_messages(self, example: dict):
        """Pop the chat messages from ``example`` and, when the row carries images
        or videos, split each text content on <image>/<video> placeholders into
        the structured multi-modal content-list format.

        Note: mutates ``example`` (pops the prompt column) and the messages in place.
        """
        messages: list = example.pop(self.prompt_key)

        if self.image_key in example or self.video_key in example:
            for message in messages:
                content = message["content"]
                content_list = []
                # keep the delimiters: the capturing group makes re.split emit them
                for segment in re.split("(<image>|<video>)", content):
                    if segment == "<image>":
                        content_list.append({"type": "image"})
                    elif segment == "<video>":
                        content_list.append({"type": "video"})
                    else:
                        content_list.append({"type": "text", "text": segment})

                message["content"] = content_list

        return messages

    def __getitem__(self, item):
        """
        Note that we also return the raw_input_ids so that it can be combined with other chat template
        """
        row_dict: dict = self.dataframe[item]
        messages = self._build_messages(row_dict)
        model_inputs = {}

        if self.processor is not None:
            # multi-modal path: the processor tokenizes text and encodes images/videos
            from verl.utils.dataset.vision_utils import process_image, process_video

            raw_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            multi_modal_data = {}

            images = None
            if self.image_key in row_dict:
                images = [process_image(image) for image in row_dict.pop(self.image_key)]
                multi_modal_data["image"] = images

            videos = None
            if self.video_key in row_dict:
                videos = [process_video(video) for video in row_dict.pop(self.video_key)]
                multi_modal_data["video"] = [video.numpy() for video in videos]

            model_inputs = self.processor(text=[raw_prompt], images=images, videos=videos, return_tensors="pt")

            input_ids = model_inputs.pop("input_ids")
            attention_mask = model_inputs.pop("attention_mask")

            if "second_per_grid_ts" in model_inputs:
                model_inputs.pop("second_per_grid_ts")

            # There's a trap here, multi_modal_inputs has to be a dict, not BatchFeature
            row_dict["multi_modal_data"] = multi_modal_data
            row_dict["multi_modal_inputs"] = dict(model_inputs)

            # second_per_grid_ts isn't used for training, just for mrope
            row_dict["multi_modal_inputs"].pop("second_per_grid_ts", None)

        else:
            raw_prompt = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
            model_inputs = self.tokenizer(raw_prompt, return_tensors="pt", add_special_tokens=False)
            input_ids = model_inputs.pop("input_ids")
            attention_mask = model_inputs.pop("attention_mask")

        # left-pad / truncate to the fixed prompt length
        input_ids, attention_mask = verl_F.postprocess_data(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=self.max_prompt_length,
            pad_token_id=self.tokenizer.pad_token_id,
            left_pad=True,
            truncation=self.truncation,
        )

        if self.processor is not None and self.processor.image_processor.__class__.__name__ == "Qwen2VLImageProcessor":
            # Qwen2-VL uses 3D mrope position ids instead of the plain cumulative mask
            from verl.models.transformers.qwen2_vl import get_rope_index

            position_ids = [
                get_rope_index(
                    self.processor,
                    input_ids=input_ids[0],
                    image_grid_thw=model_inputs.get("image_grid_thw"),
                    video_grid_thw=model_inputs.get("video_grid_thw"),
                    second_per_grid_ts=model_inputs.get("second_per_grid_ts"),
                    attention_mask=attention_mask[0],
                )
            ]  # (1, 3, seq_len)

        else:
            position_ids = compute_position_id_with_mask(attention_mask)

        row_dict["input_ids"] = input_ids[0]
        row_dict["attention_mask"] = attention_mask[0]
        row_dict["position_ids"] = position_ids[0]

        # unpadded prompt token ids, truncated the same way as the padded ids
        raw_prompt_ids = self.tokenizer.encode(raw_prompt, add_special_tokens=False)
        if len(raw_prompt_ids) > self.max_prompt_length:
            if self.truncation == "left":
                raw_prompt_ids = raw_prompt_ids[-self.max_prompt_length :]
            elif self.truncation == "right":
                raw_prompt_ids = raw_prompt_ids[: self.max_prompt_length]
            elif self.truncation == "error":
                raise RuntimeError(f"Prompt length {len(raw_prompt_ids)} is longer than {self.max_prompt_length}.")
            # NOTE(review): any other truncation value silently leaves the raw ids
            # overlong here — confirm that is intended for modes like "middle".

        row_dict["raw_prompt_ids"] = raw_prompt_ids
        # encode prompts without chat template
        if self.return_raw_chat:
            row_dict["raw_prompt"] = messages

        # add index for each prompt; tolerate a missing or None-valued extra_info column
        index = (row_dict.get("extra_info") or {}).get("index", 0)
        row_dict["index"] = index

        return row_dict

    def __getstate__(self):
        """Drop the (re-loadable) dataframe from pickled state unless an old-style
        checkpoint requires serializing the whole dataset (see resume_dataset_state)."""
        if not self.serialize_dataset:
            state = self.__dict__.copy()

            if "dataframe" in state:
                del state["dataframe"]
            return state

        return self.__dict__.copy()