Commit 8d2dc88

Merge pull request #179 from X-LANCE/yxdu
Yxdu

2 parents 6c26585 + f887fc7 · commit 8d2dc88

7 files changed: +36 −124300 lines

examples/st_covost2/covost2_zh.jsonl (−15,529 lines)

This file was deleted.

examples/st_covost2/dataset/hf_dataset.py (+11 −42)

```diff
@@ -27,47 +27,12 @@ def __init__(self,
         super().__init__()
         self.mel_size = dataset_config.get("mel_size", 80)  # 80 for whisper large v1 and v2, 128 for large v3
 
-        rank = dist.get_rank()
-
-
-        data_name = "yxdu/covost2_en_x"
-        local_dataset_path = data_name.split("/")[-1] + "_" + split + "_cache"
-
-        if os.path.exists(local_dataset_path):
-            ds = load_from_disk(local_dataset_path)
-            print(ds)
-        else:
-            if rank == 0:
-                ds = load_dataset(data_name, split=split)
-                ds = ds.cast_column("audio", Audio(sampling_rate=16000))
-                print(ds)
-
-
-
-                def prepare_dataset(example):
-                    audio_raw = whisper.pad_or_trim(example["audio"]["array"])
-
-                    audio_raw = torch.tensor(audio_raw, dtype=torch.float32)
-                    audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=self.mel_size).permute(1, 0)
-
-                    example["audio_mel"] = audio_mel
-
-
-                    return example
-
-                ds = ds.map(prepare_dataset, remove_columns="audio")
-
-                ds.save_to_disk(local_dataset_path)
-
-            dist.barrier()
-            if rank != 0:
-                if os.path.exists(local_dataset_path):
-                    ds = load_from_disk(local_dataset_path)
-                else:
-                    raise FileNotFoundError("No dataset.")
-
-
+        if split == "val":
+            split = "validation"
+        ds = load_dataset("yxdu/covost2_en_x", split=split)
+        ds = ds.cast_column("audio", Audio(sampling_rate=16000))
+        print(ds)
 
         self.ds = ds
         self.tokenizer = tokenizer
@@ -111,8 +76,12 @@ def __getitem__(self, index):
             print(target)
             self.printed = True
 
+        audio_raw = whisper.pad_or_trim(data_dict["audio"]["array"])
+        audio_raw = torch.tensor(audio_raw, dtype=torch.float32)
+        audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=self.mel_size).permute(1, 0)
+
         if self.bf16:
-            audio_mel = torch.tensor(data_dict["audio_mel"], dtype=torch.bfloat16)
+            audio_mel = audio_mel.to(torch.bfloat16)
 
 
         if self.fix_length_audio > 0:
```

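The net effect here: the rank-0 preprocess-and-cache path (load_dataset → map(prepare_dataset) → save_to_disk behind a dist.barrier()) is dropped, and each worker now loads the Hugging Face dataset directly and computes log-mel features lazily per item in __getitem__. A minimal sketch of that lazy path, assuming the openai-whisper and datasets packages (dataset name and column layout taken from the diff):

```python
# Minimal sketch of the new on-the-fly feature path, assuming the
# openai-whisper and datasets packages are installed.
import torch
import whisper
from datasets import load_dataset, Audio

ds = load_dataset("yxdu/covost2_en_x", split="validation")
ds = ds.cast_column("audio", Audio(sampling_rate=16000))  # decode/resample on access

example = ds[0]
# Pad or trim to whisper's fixed 30 s window, then compute log-mel features.
audio_raw = whisper.pad_or_trim(example["audio"]["array"])
audio_raw = torch.tensor(audio_raw, dtype=torch.float32)
audio_mel = whisper.log_mel_spectrogram(audio_raw, n_mels=128).permute(1, 0)  # (frames, n_mels)
print(audio_mel.shape)  # e.g. torch.Size([3000, 128]) with n_mels=128 for large-v3
```

This removes the cache-coordination complexity at the cost of recomputing mel features on every pass over the data.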
examples/st_covost2/inference_asr_batch.py (+6 −4)

```diff
@@ -76,11 +76,12 @@ def __len__(self):
 def Inference(kwargs: DictConfig):
 
     # Update the configuration for the training and sharding process
-    train_config, fsdp_config, model_config, log_config, dataset_config = kwargs.train_config, \
+    train_config, fsdp_config, model_config, log_config, dataset_config, ckpt_path = kwargs.train_config, \
         kwargs.fsdp_config, \
         kwargs.model_config, \
         kwargs.log_config, \
-        kwargs.dataset_config
+        kwargs.dataset_config, \
+        kwargs.ckpt_path
 
     OmegaConf.set_struct(kwargs, False)
     del kwargs["train_config"]
@@ -114,8 +115,8 @@ def Inference(kwargs: DictConfig):
 
     config = AutoConfig.from_pretrained("Qwen/Qwen2-7B")  # load the Qwen2-7B configuration
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")
-    model = CustomSLM(config, ckpt_path="cotst/model.pt")
-
+    model = CustomSLM(config, ckpt_path=ckpt_path)
+    # model = AutoModel.from_pretrained("/home/yxdu/hit/SLAM-LLM/examples/st_covost2/output/step_10/test")
 
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # FIX(MZY): put the whole model to device.
@@ -143,6 +144,7 @@ def Inference(kwargs: DictConfig):
         batch_size=train_config.val_batch_size,
         drop_last=False,
         prefetch_factor=1000,
+        persistent_workers=True,
         collate_fn=dataset_test.collator
     )
 
```

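Besides un-hardcoding the checkpoint path (it now comes from the Hydra config as ckpt_path), the DataLoader gains persistent_workers=True, which keeps worker processes alive between iterations over the loader instead of respawning them; note that both it and prefetch_factor only take effect with num_workers > 0. A runnable toy illustration (TensorDataset stands in for the script's dataset_test):

```python
# Toy illustration of the DataLoader flags added in this diff.
import torch
from torch.utils.data import DataLoader, TensorDataset

def main():
    ds = TensorDataset(torch.randn(64, 10))
    loader = DataLoader(
        ds,
        batch_size=8,
        num_workers=2,            # persistent_workers/prefetch_factor need num_workers > 0
        drop_last=False,
        prefetch_factor=4,        # batches each worker keeps pre-fetched
        persistent_workers=True,  # reuse worker processes across epochs
    )
    for epoch in range(2):
        for (batch,) in loader:  # workers stay alive between these epochs
            pass

if __name__ == "__main__":
    main()
```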
examples/st_covost2/model/slm_model.py (+10 −10)

```diff
@@ -66,12 +66,10 @@ def forward(self,
         audio_mel = kwargs.get("audio_mel", None)
         audio_mel_post_mask = kwargs.get("audio_mel_post_mask", None)  # 2x downsample for whisper
 
-
         encoder_outs = self.encoder(audio_mel.permute(0, 2, 1)).last_hidden_state  # bs*seq*dim
         encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask)
 
         input_ids = input_ids[:, 80:]
-
         inputs_embeds = self.llm.model.embed_tokens(input_ids)
         inputs_embeds = torch.cat((encoder_outs, inputs_embeds), dim=1)
 
@@ -80,14 +78,16 @@ def forward(self,
 
 
         model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels,)
-        acc = -1
-        if self.metric:
-            with torch.no_grad():
-                preds = torch.argmax(input=model_outputs.logits, dim=-1)
-                acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100)
-
-
-        return model_outputs, acc
+
+
+        with torch.no_grad():
+            preds = torch.argmax(input=model_outputs.logits, dim=-1)
+            acc = compute_accuracy(preds.detach()[:, :-1], labels.detach()[:, 1:], ignore_label=-100)
+            print(acc)
+
+        return model_outputs
+
+        # return model_outputs, acc
 
     @torch.no_grad()
     def generate(self,
```

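With the self.metric gate removed, forward() now always computes and prints token accuracy. A hedged sketch of what a compute_accuracy with this signature does (SLAM-LLM's real implementation lives elsewhere and may differ), including the one-position shift that aligns the logits at step t with the label for step t+1:

```python
# Hedged sketch of the token-accuracy computation used in forward() above.
import torch

def compute_accuracy(preds: torch.Tensor, targets: torch.Tensor, ignore_label: int = -100) -> torch.Tensor:
    """Fraction of non-ignored positions where preds == targets."""
    mask = targets != ignore_label                 # prompt/padding positions are labeled -100
    correct = (preds == targets) & mask
    return correct.sum() / mask.sum().clamp(min=1)

# Usage mirroring the diff: predictions lose their last step, labels shift
# left by one, so position t of preds is compared against token t+1.
logits = torch.randn(2, 6, 100)                    # (batch, seq, vocab)
labels = torch.randint(0, 100, (2, 6))
labels[:, :2] = -100                               # e.g. prompt tokens are ignored
preds = torch.argmax(logits, dim=-1)
acc = compute_accuracy(preds[:, :-1], labels[:, 1:], ignore_label=-100)
print(acc)
```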
examples/st_covost2/scripts/all.sh (+7 −11)

```diff
@@ -1,7 +1,7 @@
 # export TOKENIZERS_PARALLELISM=false
 export WANDB_MODE=offline
 # export HYDRA_FULL_ERROR=1
-
+export CUDA_VISIBLE_DEVICES=0,1
 if command -v nvidia-smi &> /dev/null; then
     gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
     if [ -n "$CUDA_VISIBLE_DEVICES" ]; then
@@ -15,7 +15,7 @@ current_dir=$(dirname "$current_script")
 code_dir=$(realpath "$current_dir/../../../../")
 cd ${code_dir}/SLAM-LLM
 
-source=all
+source=zh
 
 checkpoint_dir=${code_dir}/speech/data/qwen/spt-all-7B-4
 output_dir=${code_dir}/speech/data/qwen/cotst-all
@@ -24,11 +24,6 @@ encoder_path_hf=${code_dir}/speech/models/whisper-large-v3
 llm_path=${code_dir}/speech/models/Qwen2-7B
 
 
-# change your train data
-train_data_path=${code_dir}/SLAM-LLM/examples/st_covost2/test_st.jsonl
-val_data_path=${code_dir}/SLAM-LLM/examples/st_covost2/test_st.jsonl
-
-
 
 
 max_epoch=$(ls -d ${checkpoint_dir}/asr_epoch_*_step_* | sed -n 's/.*asr_epoch_\([0-9]*\)_step_\([0-9]*\).*/\1/p' | sort -n | tail -1)
@@ -40,7 +35,7 @@ final_path="${checkpoint_dir}/asr_epoch_${max_epoch}_step_${max_step}"
 
 
 ckpt_name=$final_path/model.pt
-
+ckpt_name=/home/yxdu/hit/SLAM-LLM/cotst/model.pt
 # Use the find command to search all .pt files and take the most recently modified one
 
 
@@ -62,7 +57,8 @@ hydra.run.dir=$output_dir \
 ++model_config.encoder_dim=1280 \
 ++model_config.encoder_projector=q-former \
 ++model_config.query_len=80 \
-++dataset_config.dataset=st_dataset \
+++dataset_config.dataset=hf_dataset \
+++dataset_config.file=examples/st_covost2/dataset/hf_dataset.py:get_speech_dataset \
 ++dataset_config.train_data_path=$train_data_path \
 ++dataset_config.val_data_path=$val_data_path \
 ++dataset_config.input_type=mel \
@@ -74,7 +70,7 @@ hydra.run.dir=$output_dir \
 ++train_config.freeze_encoder=true \
 ++train_config.freeze_llm=true \
 ++train_config.batching_strategy=custom \
-++train_config.gradient_accumulation_steps=1 \
+++train_config.gradient_accumulation_steps=8 \
 ++train_config.warmup_steps=1000 \
 ++train_config.total_steps=1000000 \
 ++train_config.lr=1e-5 \
@@ -101,7 +97,7 @@ torchrun \
 ++fsdp_config.pure_bf16=true \
 ++log_config.use_wandb=true \
 ++log_config.wandb_project_name=cot \
-++train_config.validation_interval=100 \
+++train_config.validation_interval=10000 \
 ++log_config.wandb_exp_name=all \
 ++train_config.use_peft=false \
 $hydra_args
```

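Two of these config changes deserve a note: raising gradient_accumulation_steps from 1 to 8 multiplies the effective batch per optimizer step by 8 (per-GPU micro-batch × GPU count × 8), and dataset_config.file now points at a "file.py:function" entry point. A sketch of how such a spec can be resolved with importlib (SLAM-LLM's actual loader may differ in detail):

```python
# Sketch of resolving a "path/to/module.py:function" dataset spec like the
# ++dataset_config.file value above; SLAM-LLM's real loader may differ.
import importlib.util

def resolve_entry(spec: str):
    path, func_name = spec.split(":")
    module_spec = importlib.util.spec_from_file_location("custom_dataset", path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)  # execute the file as a module
    return getattr(module, func_name)       # e.g. get_speech_dataset

get_speech_dataset = resolve_entry(
    "examples/st_covost2/dataset/hf_dataset.py:get_speech_dataset"
)
```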
examples/st_covost2/scripts/infer_enzh.sh (+2 −1)

```diff
@@ -1,7 +1,7 @@
 export MASTER_ADDR=localhost
 export MASTER_PORT=12345
 export WANDB_MODE=offline
-
+export CUDA_VISIBLE_DEVICES=2,3
 if command -v nvidia-smi &> /dev/null; then
     gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
     if [ -n "$CUDA_VISIBLE_DEVICES" ]; then
@@ -32,6 +32,7 @@ if [ ! -f "$ckpt_path" ]; then
     echo "Download ckpt..."
     git clone https://huggingface.co/yxdu/cotst
 fi
+
 echo $ckpt_path
 
 
```

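For reference, the same checkpoint fetch can be done without git (which typically needs git-lfs for the weight file) via huggingface_hub; a hedged alternative to the clone above, assuming the package is installed and using the repo id from the script:

```python
# Alternative to `git clone https://huggingface.co/yxdu/cotst`.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="yxdu/cotst")  # caches under ~/.cache/huggingface
ckpt_path = f"{local_dir}/model.pt"
print(ckpt_path)
```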