Files
vosk-datacleaner/vosk/test_files/batch_confirm_hf.py
2025-07-31 17:35:08 +03:30

190 lines
6.9 KiB
Python

from datasets import load_dataset, Audio, Dataset
import soundfile as sf
import requests
import os
from tqdm import tqdm
import pandas as pd
import json
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np
from huggingface_hub import HfApi, create_repo
# --- Run configuration -------------------------------------------------------
API_URL = "http://localhost:5000/batch_confirm"  # Vosk batch-confirmation endpoint
batch_size = 8  # Number of clips sent per request

# Hugging Face configuration
HF_DATASET_NAME = "dpr2000/persian-cv17-confirmed"  # Change this to your desired dataset name
HF_PRIVATE = True  # Set to False to publish the dataset publicly

output_dir = "confirmed_dataset"

# --- Source data -------------------------------------------------------------
# Load the first 500 validated clips and have the audio column decoded at 16 kHz.
print("Loading dataset...")
ds = load_dataset(
    "Ashegh-Sad-Warrior/Persian_Common_Voice_17_0",
    split="validated[:500]",
    streaming=False,
).cast_column("audio", Audio(sampling_rate=16000))
os.makedirs(output_dir, exist_ok=True)

# Collects {"audio": <array>, "transcription": <str>} records for clips the
# confirmation service accepts; appended to by the batch loop below.
confirmed = []
def save_flac(audio_array, path, sample_rate=16000):
    """Write ``audio_array`` to ``path`` as a FLAC file.

    Args:
        audio_array: Sample array accepted by ``soundfile.write``.
        path: Destination file path (overwritten if it exists).
        sample_rate: Sampling rate in Hz recorded in the FLAC file. Defaults to
            16000 to match the ``Audio(sampling_rate=16000)`` cast applied to
            the source dataset.
    """
    sf.write(path, audio_array, sample_rate, format="FLAC")
def _request_confirmation(files, references, batch_index):
    """POST one batch to the confirmation service and return per-clip results.

    Args:
        files: Mapping of multipart field name -> open binary file handle.
        references: Reference transcriptions, one per file, in the same order.
        batch_index: Dataset offset of the batch, used only in log messages.

    Returns:
        A list with exactly ``len(references)`` entries. Each entry is the
        service's result for that clip, or ``None`` when the request failed,
        the response was malformed, or the service returned too few results.
    """
    failed = [None] * len(references)
    data = {"references": json.dumps(references)}
    try:
        response = requests.post(API_URL, files=files, data=data, timeout=120)
    except requests.RequestException as e:
        print(f"Batch {batch_index} failed: {e}")
        return failed
    if response.status_code != 200:
        print(f"Batch {batch_index} failed: HTTP {response.status_code} - {response.text}")
        return failed
    try:
        resp_json = response.json()
    except ValueError as e:  # body is not valid JSON
        print(f"Batch {batch_index} failed: {e}")
        return failed
    if not isinstance(resp_json, dict) or "results" not in resp_json:
        print(f"Batch {batch_index} failed: 'results' key missing in response: {resp_json}")
        return failed
    results = resp_json["results"]
    if not isinstance(results, list):
        print(f"Batch {batch_index} failed: 'results' is not a list: {results!r}")
        return failed
    if len(results) != len(references):
        print(
            f"Batch {batch_index}: expected {len(references)} results, "
            f"got {len(results)}; unmatched clips are treated as unconfirmed"
        )
    # Pad with None / truncate so results line up 1:1 with the clips sent.
    return (results + failed)[:len(references)]


print("Processing batches...")
for i in tqdm(range(0, len(ds), batch_size)):
    # Slicing a Dataset yields a dict of column lists, not a list of rows.
    batch = ds[i:i + batch_size]
    files = {}
    references = []
    temp_flacs = []
    audio_arrays = []
    try:
        for j, (audio, sentence) in enumerate(zip(batch["audio"], batch["sentence"])):
            flac_path = f"temp_{i + j}.flac"
            # Track the path before writing so cleanup covers partial failures.
            temp_flacs.append(flac_path)
            save_flac(audio["array"], flac_path)
            files[f"audio{j}"] = open(flac_path, "rb")
            references.append(sentence)
            audio_arrays.append(audio["array"])  # Keep the decoded array for `confirmed`
        results = _request_confirmation(files, references, i)
    finally:
        # Close handles before deleting: Windows refuses to remove open files.
        for f in files.values():
            f.close()
        # Remove every temp file of the batch, regardless of how many results came back.
        for flac_path in temp_flacs:
            if os.path.exists(flac_path):
                os.remove(flac_path)
    for audio_array, reference, result in zip(audio_arrays, references, results):
        if isinstance(result, dict) and result.get("confirmed"):
            # Save confirmed audio array and transcription
            confirmed.append({"audio": audio_array, "transcription": reference})
# Upper bound on how many parquet shards to write. The actual count is also
# capped by the number of samples so no shard is empty.
MAX_SHARDS = 1


def extract_minimal(example):
    """Map a confirmed record to the minimal ``{"audio", "text"}`` schema.

    The audio is clipped to [-1.0, 1.0], scaled by 32767, cast to int16 and
    stored as raw int16 bytes (native byte order, no header). The
    transcription is copied to ``text``.
    """
    # Clip first so out-of-range samples cannot wrap around when cast to int16.
    audio_float32 = np.clip(np.asarray(example["audio"], dtype=np.float32), -1.0, 1.0)
    audio_int16 = (audio_float32 * 32767).astype(np.int16)
    return {
        "audio": audio_int16.tobytes(),  # Raw int16 PCM bytes
        "text": example["transcription"],
    }


def _write_shards(dataset, out_dir, max_shards=MAX_SHARDS):
    """Write ``dataset`` to zstd-compressed parquet shards in ``out_dir``.

    Returns the list of shard paths actually written, in order (empty when
    ``dataset`` has no rows).
    """
    n_rows = len(dataset)
    if n_rows == 0:
        return []
    num_shards = max(1, min(max_shards, n_rows))
    # Ceiling division, so the shards cover every row without over-sizing
    # shards when num_shards divides n_rows evenly.
    shard_size = -(-n_rows // num_shards)
    starts = range(0, n_rows, shard_size)
    written = []
    for index, start in enumerate(starts):
        end = min(start + shard_size, n_rows)
        shard = dataset.select(range(start, end))
        table = pa.Table.from_pandas(shard.to_pandas())  # Convert to PyArrow table
        shard_path = os.path.join(out_dir, f"confirmed_shard_{index:02}.parquet")
        pq.write_table(
            table,
            shard_path,
            compression="zstd",
            compression_level=22,  # Maximum zstd compression level
            use_dictionary=True,
            version="2.6",
        )
        written.append(shard_path)
        print(f"🔹 Shard {index + 1}/{len(starts)}: {len(shard)} samples saved")
    return written


def _push_to_hub(shard_paths, total_samples, repo_id, private):
    """Upload ``shard_paths`` plus a ``dataset_info.json`` summary to the Hub.

    Failures are printed and swallowed so the local shards remain usable.
    """
    print(f"\n🚀 Pushing dataset to Hugging Face Hub as '{repo_id}'...")
    try:
        api = HfApi()
        try:
            create_repo(
                repo_id=repo_id,
                repo_type="dataset",
                private=private,
                exist_ok=True,
            )
            print(f"✅ Repository '{repo_id}' created/verified")
        except Exception as e:
            # Best-effort: continue; a real problem will surface in the uploads.
            print(f"⚠️ Repository creation: {e}")
        # Upload exactly the shards written in this run (never stale files).
        for index, shard_path in enumerate(shard_paths):
            api.upload_file(
                path_or_fileobj=shard_path,
                path_in_repo=os.path.basename(shard_path),
                repo_id=repo_id,
                repo_type="dataset",
            )
            print(f"📤 Uploaded shard {index + 1}/{len(shard_paths)}")
        dataset_info = {
            "dataset_name": repo_id,
            "description": "Persian Common Voice confirmed samples for Whisper fine-tuning",
            "total_samples": total_samples,
            "num_shards": len(shard_paths),
            "audio_format": "int16 PCM, 16kHz",
            "columns": ["audio", "text"],
            "source_dataset": "Ashegh-Sad-Warrior/Persian_Common_Voice_17_0",
            "processing": "Vosk API batch confirmation",
        }
        # Upload the JSON from memory as UTF-8 bytes. There is no temp file to leak if the upload fails.
        info_bytes = json.dumps(dataset_info, indent=2, ensure_ascii=False).encode("utf-8")
        api.upload_file(
            path_or_fileobj=info_bytes,
            path_in_repo="dataset_info.json",
            repo_id=repo_id,
            repo_type="dataset",
        )
        print(f"🎉 Dataset successfully pushed to: https://huggingface.co/datasets/{repo_id}")
    except Exception as e:
        print(f"❌ Failed to push to Hugging Face: {e}")
        print("💡 Make sure you're logged in with: huggingface-cli login")


# Save confirmed data using sharding approach
if confirmed:
    print(f"\n🔄 Saving {len(confirmed)} confirmed samples...")
    # Convert confirmed data to HuggingFace dataset format
    confirmed_dataset = Dataset.from_list(confirmed)
    confirmed_dataset = confirmed_dataset.map(
        extract_minimal, remove_columns=confirmed_dataset.column_names
    )
    shard_paths = _write_shards(confirmed_dataset, output_dir)
    print(f"\n✅ All confirmed data saved in {len(shard_paths)} shards in `{output_dir}/`")
    _push_to_hub(shard_paths, len(confirmed), HF_DATASET_NAME, HF_PRIVATE)
else:
    print("❌ No confirmed samples to save")