
Commit 918668f

Merge pull request #198 from X-LANCE/dev-slam-omni
Fix: Update support for jsonl format
2 parents 43b5293 + 5cb725b commit 918668f

17 files changed: +57 −32 lines

examples/s2s/README.md (+5 −2)

````diff
@@ -41,9 +41,12 @@ ds = load_dataset("DATASET_NAME")
 ### JSONL
 We also support JSONL format for its concise structure. Below is an example:
 ```jsonl
-{"key": "1", "source_wav": "/xxx/1.wav", "source_text": "Can you recommend some Chinese food for me?", "target_wav": "/xxx/1.wav", "target_text": "Sure! I recommend trying dumplings, Peking duck, and mapo tofu for a mix of flavors and textures in Chinese cuisine. These dishes offer a good balance of savory, spicy, and crispy elements."}
+{"key": "1", "source_wav": "/xxx/1.wav", "source_text": "Can you recommend some Chinese food for me?", "target_token": [742, 383, 455, 619, 180], "target_text": "Sure! I recommend trying dumplings, Peking duck, and mapo tofu for a mix of flavors and textures in Chinese cuisine. These dishes offer a good balance of savory, spicy, and crispy elements."}
 ```
 
+🔔**Update**:
+We now use `target_token` to replace the `target_wav` field. When using your own data, you need to generate the corresponding audio response tokens yourself (e.g., using [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) tokens in SLAM-Omni).
+
 ## Checkpoints
 We reproduced the single-stage fine-tuning results of SLAM-Omni with a group size of **3**. The following checkpoints are available for download:
 - [Single-Round Dialogue (English)](https://drive.google.com/drive/folders/1ZmM1h5ZTvS-piuN-msmctmZdi51GWLAu?usp=sharing): Trained on VoiceAssistant-400K.
@@ -144,4 +147,4 @@ Mini-Omni:
 
 
 ## License
-Our code is released under MIT License. The Chinese dialogue model is licensed under GPL-3.0 due to its use of Belle data and is intended for research purposes only.
+Our code is released under MIT License. The Chinese dialogue model is licensed under GPL-3.0 due to its use of Belle data and is intended for research purposes only.
````
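For reference, a minimal sketch of emitting one record in the updated schema. The output file name and the token IDs here are illustrative placeholders, not values from the repo; in practice the tokens would come from whatever codec you use (CosyVoice in SLAM-Omni).

```python
# Minimal sketch; file name and token IDs are placeholder assumptions.
import json

sample = {
    "key": "1",
    "source_wav": "/xxx/1.wav",                 # input speech for this turn
    "source_text": "Can you recommend some Chinese food for me?",
    "target_token": [742, 383, 455, 619, 180],  # pre-extracted audio response tokens
    "target_text": "Sure! I recommend trying dumplings, Peking duck, and mapo tofu.",
}

# One JSON object per line, matching the jsonl_demo-*.jsonl files added in this commit.
with open("train.jsonl", "a", encoding="utf-8") as fout:
    fout.write(json.dumps(sample, ensure_ascii=False) + "\n")
```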

examples/s2s/demo/demo_data/jsonl_demo-en.jsonl (+10)
Large diffs are not rendered by default.

examples/s2s/demo/demo_data/jsonl_demo-zh.jsonl (+10)
Large diffs are not rendered by default.

examples/s2s/demo/demo_data/jsonl_demo.jsonl (−6)
This file was deleted.

examples/s2s/s2s_config.py (+1 −1)

```diff
@@ -189,7 +189,7 @@ class DataConfig:
         "help": "whether input is normalized, used for models such as wavlm"
     })
     seed: int = 42
-    manifest_format: str = field(default="datasets", metadata={ "help": "alternative: jsonl" })
+    manifest_format: str = field(default="parquet", metadata={ "help": "alternative: jsonl" })
     split_size: float = 0.1
 
     vocab_config: VocabConfig = field(default_factory=VocabConfig)
```
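The rename from `datasets` to `parquet` only changes the default value and its accepted spelling; downstream code reads the field unchanged. Below is a trimmed, runnable stand-in for `DataConfig` (the real class has many more fields) showing how the new default and an explicit override behave:

```python
# Trimmed stand-in for DataConfig; only the fields around the change are kept.
from dataclasses import dataclass, field

@dataclass
class DataConfig:
    seed: int = 42
    manifest_format: str = field(default="parquet", metadata={"help": "alternative: jsonl"})
    split_size: float = 0.1

cfg = DataConfig()                               # picks up the new default: "parquet"
assert cfg.manifest_format in ["parquet", "jsonl"]
cfg_jsonl = DataConfig(manifest_format="jsonl")  # same effect as ++dataset_config.manifest_format=jsonl
```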

examples/s2s/scripts/finetune/finetune_s2s.sh (+3 −2)

```diff
@@ -32,9 +32,10 @@ num_latency_tokens=0          # number of delay tokens (in front of the ge
 do_layershift=false           # if false, tokens in each layers use the same codebook, otherwise, use different codebooks
 
 # dataset settings
+manifest_format=parquet       # parquet or jsonl
 train_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
 val_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
-load_from_cache_file=true     # set to true if you have already generated the cache file, otherwise set to false
+load_from_cache_file=true     # set to true if you have already generated the cache file, otherwise set to false
 
 # training settings
 batch_size_training=6
@@ -89,7 +90,7 @@ hydra.run.dir=$output_dir \
 ++dataset_config.input_type=mel \
 ++dataset_config.mel_size=$mel_size \
 ++dataset_config.seed=42 \
-++dataset_config.manifest_format=datasets \
+++dataset_config.manifest_format=$manifest_format \
 ++dataset_config.split_size=$split_size \
 ++dataset_config.load_from_cache_file=$load_from_cache_file \
 ++dataset_config.task_type=$task_type \
```

examples/s2s/scripts/finetune/finetune_s2s_group.sh (+3 −2)

```diff
@@ -32,9 +32,10 @@ num_latency_tokens=0          # number of delay tokens (in front of the ge
 do_layershift=false           # if false, tokens in each layers use the same codebook, otherwise, use different codebooks
 
 # dataset settings
+manifest_format=parquet       # parquet or jsonl
 train_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
 val_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
-load_from_cache_file=true     # set to true if you have already generated the cache file, otherwise set to false
+load_from_cache_file=true     # set to true if you have already generated the cache file, otherwise set to false
 
 # training settings
 batch_size_training=6
@@ -96,7 +97,7 @@ hydra.run.dir=$output_dir \
 ++dataset_config.input_type=mel \
 ++dataset_config.mel_size=$mel_size \
 ++dataset_config.seed=42 \
-++dataset_config.manifest_format=datasets \
+++dataset_config.manifest_format=$manifest_format \
 ++dataset_config.split_size=$split_size \
 ++dataset_config.load_from_cache_file=$load_from_cache_file \
 ++dataset_config.task_type=$task_type \
```

examples/s2s/scripts/finetune/mini-omni/finetune_s2s.sh (+2 −1)

```diff
@@ -20,6 +20,7 @@ mel_size=80 # 80 128 ( only whisper-large-v3 supports 128 )
 llm_dim=896                   # 896 1536 2048 3584 -> 0.5B 1.5B 3B 7B
 
 # dataset settings
+manifest_format=parquet       # parquet or jsonl
 train_data_path="/valleblob/v-wenxichen/data/s2s/VoiceAssistant-400K"
 val_data_path="/valleblob/v-wenxichen/data/s2s/VoiceAssistant-400K"
 load_from_cache_file=false    # set to true if you have already generated the cache file, otherwise set to false
@@ -75,7 +76,7 @@ hydra.run.dir=$output_dir \
 ++dataset_config.input_type=mel \
 ++dataset_config.mel_size=$mel_size \
 ++dataset_config.seed=42 \
-++dataset_config.manifest_format=datasets \
+++dataset_config.manifest_format=$manifest_format \
 ++dataset_config.split_size=$split_size \
 ++dataset_config.load_from_cache_file=$load_from_cache_file \
 ++dataset_config.task_type=$task_type \
```

examples/s2s/scripts/inference/inference_s2s_batch.sh (+1 −1)

```diff
@@ -39,7 +39,7 @@ ckpt_path=/valleblob/v-wenxichen/exp/s2s/s2s_train_v3-gpu16-btz3-lr5e-4-fp16-epo
 # val_data_path=/home/v-wenxichen/SLAM-LLM/examples/s2s/demo/data/${split}.jsonl
 
 # huggingface dataset
-manifest_format=datasets
+manifest_format=parquet
 val_data_path="/valleblob/v-wenxichen/data/s2s/VoiceAssistant-400K-v1/test"
 load_from_cache_file=false
 dataset_sample_seed=777
```

examples/s2s/scripts/inference/mini-omni/inference_s2s_batch.sh (+1 −1)

```diff
@@ -29,7 +29,7 @@ split=test
 # val_data_path=/home/v-wenxichen/SLAM-LLM/examples/s2s/demo/data/${split}.jsonl
 
 # huggingface dataset
-manifest_format=datasets
+manifest_format=parquet
 val_data_path="gpt-omni/VoiceAssistant-400K"
 load_from_cache_file=true
 dataset_sample_seed=777
```

examples/s2s/scripts/inference/mini-omni/inference_s2s_stream.sh (+1 −1)

```diff
@@ -28,7 +28,7 @@ split=test
 # val_data_path=/home/v-wenxichen/SLAM-LLM/examples/s2s/demo/data/${split}.jsonl
 
 # huggingface dataset
-manifest_format=datasets
+manifest_format=parquet
 val_data_path="gpt-omni/VoiceAssistant-400K"
 load_from_cache_file=true
 dataset_sample_seed=1234
```

examples/s2s/scripts/inference/mini-omni/inference_tts.sh (+1 −1)

```diff
@@ -25,7 +25,7 @@ split=test
 # val_data_path=/home/v-wenxichen/SLAM-LLM/examples/s2s/demo/data/${split}.jsonl
 
 # huggingface dataset
-manifest_format=datasets
+manifest_format=parquet
 val_data_path="gpt-omni/VoiceAssistant-400K"
 load_from_cache_file=true
 dataset_sample_seed=1234
```

examples/s2s/scripts/pretrain/pretrain_asr.sh (+2 −1)

```diff
@@ -32,6 +32,7 @@ num_latency_tokens=0          # number of delay tokens (in front of the ge
 do_layershift=false           # if false, tokens in each layers use the same codebook, otherwise, use different codebooks
 
 # dataset settings
+manifest_format=parquet       # parquet or jsonl
 train_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
 val_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
 load_from_cache_file=true     # set to true if you have already generated the cache file, otherwise set to false
@@ -96,7 +97,7 @@ hydra.run.dir=$output_dir \
 ++dataset_config.input_type=mel \
 ++dataset_config.mel_size=$mel_size \
 ++dataset_config.seed=42 \
-++dataset_config.manifest_format=datasets \
+++dataset_config.manifest_format=$manifest_format \
 ++dataset_config.split_size=$split_size \
 ++dataset_config.load_from_cache_file=$load_from_cache_file \
 ++dataset_config.task_type=$task_type \
```

examples/s2s/scripts/pretrain/pretrain_asr_debug.sh (+2 −1)

```diff
@@ -32,6 +32,7 @@ num_latency_tokens=0          # number of latency tokens (in front of the
 do_layershift=false           # if false, tokens in each layers use the same codebook, otherwise, use different codebooks
 
 # dataset settings
+manifest_format=parquet       # parquet or jsonl
 train_data_path="/valleblob/v-wenxichen/data/s2s/parquet_data_test/en"
 val_data_path="/valleblob/v-wenxichen/data/s2s/parquet_data_test/en"
 load_from_cache_file=true     # set to true if you have already generated the cache file, otherwise set to false
@@ -97,7 +98,7 @@ hydra.run.dir=$output_dir \
 ++dataset_config.input_type=mel \
 ++dataset_config.mel_size=$mel_size \
 ++dataset_config.seed=42 \
-++dataset_config.manifest_format=datasets \
+++dataset_config.manifest_format=$manifest_format \
 ++dataset_config.split_size=$split_size \
 ++dataset_config.load_from_cache_file=$load_from_cache_file \
 ++dataset_config.task_type=$task_type \
```

examples/s2s/scripts/pretrain/pretrain_tts.sh (+2 −1)

```diff
@@ -28,6 +28,7 @@ num_latency_tokens=0          # number of delay tokens (in front of the ge
 do_layershift=false           # if false, tokens in each layers use the same codebook, otherwise, use different codebooks
 
 # dataset settings
+manifest_format=parquet       # parquet or jsonl
 train_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
 val_data_path=worstchan/VoiceAssistant-400K-SLAM-Omni
 load_from_cache_file=true     # set to true if you have already generated the cache file, otherwise set to false
@@ -90,7 +91,7 @@ hydra.run.dir=$output_dir \
 ++dataset_config.val_data_path=$val_data_path \
 ++dataset_config.input_type=mel \
 ++dataset_config.seed=42 \
-++dataset_config.manifest_format=datasets \
+++dataset_config.manifest_format=$manifest_format \
 ++dataset_config.split_size=$split_size \
 ++dataset_config.load_from_cache_file=$load_from_cache_file \
 ++dataset_config.task_type=$task_type \
```

examples/s2s/scripts/pretrain/pretrain_tts_debug.sh (+2 −1)

```diff
@@ -26,6 +26,7 @@ num_latency_tokens=0          # number of latency tokens (in front of the
 do_layershift=false           # if false, tokens in each layers use the same codebook, otherwise, use different codebooks
 
 # dataset settings
+manifest_format=parquet       # parquet or jsonl
 train_data_path="/valleblob/v-wenxichen/data/debug/1"
 val_data_path="/valleblob/v-wenxichen/data/debug/1"
 load_from_cache_file=true     # set to true if you have already generated the cache file, otherwise set to false
@@ -82,7 +83,7 @@ hydra.run.dir=$output_dir \
 ++dataset_config.val_data_path=$val_data_path \
 ++dataset_config.input_type=mel \
 ++dataset_config.seed=42 \
-++dataset_config.manifest_format=datasets \
+++dataset_config.manifest_format=$manifest_format \
 ++dataset_config.split_size=$split_size \
 ++dataset_config.load_from_cache_file=$load_from_cache_file \
 ++dataset_config.task_type=$task_type \
```

examples/s2s/speech_dataset_s2s.py (+11 −10)

```diff
@@ -31,11 +31,11 @@ def __init__(self,
         self.inference_mode = dataset_config.get("inference_mode", False)
         self.normalize = dataset_config.get("normalize", False)
         self.input_type = dataset_config.get("input_type", None)
-        self.manifest_format = dataset_config.get("manifest_format", "datasets")
+        self.manifest_format = dataset_config.get("manifest_format", "parquet")
         self.seed = dataset_config.get("seed", 42)
         self.split_size = dataset_config.get("split_size", 0.1)
         assert self.input_type in ["raw", "mel"], "input_type must be one of [raw, mel]"
-        assert self.manifest_format in ["datasets", "jsonl"], "manifest_format must be one of [datasets, jsonl]"
+        assert self.manifest_format in ["parquet", "jsonl"], "manifest_format must be one of [parquet, jsonl]"
 
         # vocab config
         self.vocab_config = dataset_config.get("vocab_config", None)
@@ -88,7 +88,7 @@ def __init__(self,
         self.data_list = []
 
         # TODO: design a better way to load data
-        if self.manifest_format == "datasets":
+        if self.manifest_format == "parquet":
             from datasets import load_dataset, load_from_disk
             if dataset_config.load_from_cache_file:
                 ds = load_dataset(dataset_config.train_data_path) # load_from huggingface datasets
@@ -99,7 +99,7 @@ def __init__(self,
                 self.data_list = train_val_split['train']
             else:
                 self.data_list = train_val_split['test']
-        else:
+        elif self.manifest_format == "jsonl":
             if split == "train":
                 with open(dataset_config.train_data_path, encoding='utf-8') as fin:
                     for line in fin:
@@ -110,6 +110,8 @@ def __init__(self,
                 for line in fin:
                     data_dict = json.loads(line.strip())
                     self.data_list.append(data_dict)
+        else:
+            raise ValueError("manifest_format must be one of [parquet, jsonl]")
 
     def get_source_len(self, data_dict):
         return data_dict["source_len"]
@@ -120,16 +122,15 @@ def get_target_len(self, data_dict):
     def __len__(self):
         return len(self.data_list)
 
-    # NOTE: here datasets format is just for VoiceAssistant-400K dataset, and we only support the whisper format
     def extract_audio_feature(self, audio_path):
         # audio path is a dictionary, resample the audio to 16kHz
-        if self.manifest_format == "datasets" and isinstance(audio_path, dict):
+        if self.manifest_format == "parquet" and isinstance(audio_path, dict):
             audio_raw = audio_path['array']
             audio_raw_sr = audio_path['sampling_rate']
             if not isinstance(audio_raw, np.ndarray):
                 audio_raw = np.array(audio_raw)
             audio_raw = librosa.resample(audio_raw, orig_sr=audio_raw_sr, target_sr=16000).astype(np.float32)
-        elif self.manifest_format == "datasets" and (isinstance(audio_path, str) or isinstance(audio_path, list)):
+        elif (self.manifest_format == "parquet" and (isinstance(audio_path, str) or isinstance(audio_path, list))) or (self.manifest_format == "jsonl" and isinstance(audio_path, list)):
             if self.code_type == "SNAC":
                 audio_res, audio_length = get_snac_answer_token(audio_path)
             elif self.code_type == "CosyVoice":
@@ -233,7 +234,7 @@ def __getitem__(self, index):
         audio_length = 0
         target_audio_length = 0
 
-        if self.manifest_format == "datasets":
+        if self.manifest_format == "parquet":
             source_audio = data_dict.get("question_audio", None)
             if self.code_type == "SNAC":
                 target_audio = data_dict.get("answer_snac", None)
@@ -245,12 +246,12 @@ def __getitem__(self, index):
             key = source_audio['path']
         elif self.manifest_format == "jsonl":
             source_audio = data_dict.get("source_wav", None)
-            target_audio = data_dict.get("target_wav", None)
+            target_audio = data_dict.get("target_token", None)
             source_text = data_dict.get("source_text", None)
             target_text = data_dict.get("target_text", None)
             key = data_dict.get("key", None)
         else:
-            raise ValueError("manifest_format must be one of [datasets, jsonl]")
+            raise ValueError("manifest_format must be one of [parquet, jsonl]")
 
         if task_type == "s2s" or task_type == "asr":
             audio_mel, audio_length = self.extract_audio_feature(source_audio)
```
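Taken together, these edits make `manifest_format` select between a Hugging Face parquet dataset and a plain JSONL manifest, with an explicit error for anything else. Below is a simplified, self-contained sketch of that loading branch; `dataset_config` is assumed to expose the fields from `s2s_config.py`, and the real `__init__` additionally handles `load_from_cache_file`, caching, and the `split_size`-based train/val split.

```python
# Simplified sketch of the post-change manifest loading (not the full __init__).
import json

def load_data_list(dataset_config, split="train"):
    path = dataset_config.train_data_path if split == "train" else dataset_config.val_data_path
    if dataset_config.manifest_format == "parquet":
        from datasets import load_dataset
        ds = load_dataset(path)          # Hugging Face hub id or local path
        return ds["train"]               # real code further splits train/val via split_size
    elif dataset_config.manifest_format == "jsonl":
        data_list = []
        with open(path, encoding="utf-8") as fin:
            for line in fin:             # one JSON object per line
                data_list.append(json.loads(line.strip()))
        return data_list
    else:
        raise ValueError("manifest_format must be one of [parquet, jsonl]")
```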
