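"""Filter Persian Common Voice samples through a local batch-confirmation API
(Vosk-based, per the dataset_info written below), store the confirmed
audio/transcription pairs as zstd-compressed Parquet shards, and push the
result to the Hugging Face Hub.

Assumes a confirmation server listening at API_URL and an HF_TOKEN environment
variable for the upload step.
"""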
from datasets import load_dataset, Audio, Dataset
import soundfile as sf
import requests
import os
from tqdm import tqdm
import pandas as pd
import json
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np
from huggingface_hub import HfApi, create_repo

# Load the dataset with audio decoding
print("Loading dataset...")
ds = load_dataset(
    "Ashegh-Sad-Warrior/Persian_Common_Voice_17_0",
    split="validated[:5]",
    streaming=False,
).cast_column("audio", Audio(sampling_rate=16000))

output_dir = "confirmed_dataset"
os.makedirs(output_dir, exist_ok=True)
confirmed = []

API_URL = "http://localhost:5000/batch_confirm"
batch_size = 8

# Hugging Face configuration
HF_DATASET_NAME = "dpr2000/persian-cv17-confirmed"  # Change this to your desired dataset name
HF_PRIVATE = True  # Set to True for a private dataset


def save_flac(audio_array, path):
    sf.write(path, audio_array, 16000, format="FLAC")
print("Processing batches...")
|
|
for i in tqdm(range(0, len(ds), batch_size)):
|
|
batch = ds[i:i+batch_size]
|
|
files = {}
|
|
references = []
|
|
temp_flacs = []
|
|
audio_arrays = []
|
|
# Fix: batch is a dict of lists
|
|
for j in range(len(batch["audio"])):
|
|
audio = batch["audio"][j]
|
|
flac_path = f"temp_{i+j}.flac"
|
|
save_flac(audio["array"], flac_path)
|
|
files[f"audio{j}"] = open(flac_path, "rb")
|
|
references.append(batch["sentence"][j])
|
|
temp_flacs.append(flac_path)
|
|
audio_arrays.append(audio["array"]) # Store the array for confirmed
|
|
data = {"references": json.dumps(references)}
|
|
try:
|
|
response = requests.post(API_URL, files=files, data=data, timeout=120)
|
|
if response.status_code == 200:
|
|
resp_json = response.json()
|
|
if "results" in resp_json:
|
|
results = resp_json["results"]
|
|
else:
|
|
print(f"Batch {i} failed: 'results' key missing in response: {resp_json}")
|
|
results = [None] * len(references)
|
|
else:
|
|
print(f"Batch {i} failed: HTTP {response.status_code} - {response.text}")
|
|
results = [None] * len(references)
|
|
except Exception as e:
|
|
print(f"Batch {i} failed: {e}")
|
|
results = [None] * len(references)
|
|
for j, result in enumerate(results):
|
|
if result and result.get("confirmed"):
|
|
# Save confirmed audio array and transcription
|
|
confirmed.append({"audio": audio_arrays[j], "transcription": references[j]})
|
|
os.remove(temp_flacs[j])
|
|
else:
|
|
os.remove(temp_flacs[j])
|
|
for f in files.values():
|
|
f.close()
|
|
|
|


# Save confirmed data using a sharding approach
if confirmed:
    print(f"\n🔄 Saving {len(confirmed)} confirmed samples...")

    # Convert confirmed samples to a compact HuggingFace dataset format
    def extract_minimal(example):
        # Convert float32 audio (range -1.0 to 1.0) to int16 (range -32768 to 32767)
        audio_float32 = np.array(example["audio"], dtype=np.float32)
        # Ensure the audio is in the valid range, then scale to int16
        audio_float32 = np.clip(audio_float32, -1.0, 1.0)
        audio_int16 = (audio_float32 * 32767).astype(np.int16)
        return {
            # Stored as raw int16 PCM bytes (16 kHz); decode back to float32
            # before feeding a Whisper feature extractor
            "audio": audio_int16.tobytes(),
            "text": example["transcription"],
        }

    # Create a dataset from the confirmed samples
    confirmed_dataset = Dataset.from_list(confirmed)
    confirmed_dataset = confirmed_dataset.map(extract_minimal, remove_columns=confirmed_dataset.column_names)

    # Sharding parameters: a single shard here; min() keeps the shard count
    # from ever exceeding the number of samples
    num_shards = min(1, len(confirmed))
    shard_size = len(confirmed_dataset) // num_shards + 1

    # Write each shard separately
    for i in range(num_shards):
        start = i * shard_size
        end = min(len(confirmed_dataset), (i + 1) * shard_size)

        if start >= len(confirmed_dataset):
            break

        shard = confirmed_dataset.select(range(start, end))
        table = pa.Table.from_pandas(shard.to_pandas())  # Convert to a PyArrow table

        shard_path = os.path.join(output_dir, f"confirmed_shard_{i:02}.parquet")

        pq.write_table(
            table,
            shard_path,
            compression="zstd",
            compression_level=22,  # Maximum zstd compression
            use_dictionary=True,
            version="2.6",
        )

        print(f"🔹 Shard {i+1}/{num_shards}: {len(shard)} samples saved")

    print(f"\n✅ All confirmed data saved in {num_shards} shards in `{output_dir}/`")

    # Push to Hugging Face Hub
    print(f"\n🚀 Pushing dataset to Hugging Face Hub as '{HF_DATASET_NAME}'...")
    try:
        # Initialize the HF API with a token from the environment
        api = HfApi(token=os.getenv("HF_TOKEN"))

        # Create the repository (private if specified)
        repo_created = False
        try:
            create_repo(
                repo_id=HF_DATASET_NAME,
                repo_type="dataset",
                private=HF_PRIVATE,
                exist_ok=True,
            )
            print(f"✅ Repository '{HF_DATASET_NAME}' created/verified")
            repo_created = True
        except Exception as e:
            print(f"⚠️ Repository creation failed: {e}")
            print("💡 Please create the repository manually on Hugging Face Hub first,")
            print("💡 or change HF_DATASET_NAME to use your own username")

        if not repo_created:
            print("💡 Skipping upload due to repository creation failure")
        else:
            # Create a dataset info file
            dataset_info = {
                "dataset_name": HF_DATASET_NAME,
                "description": "Persian Common Voice confirmed samples for Whisper fine-tuning",
                "total_samples": len(confirmed),
                "num_shards": num_shards,
                "audio_format": "int16 PCM, 16kHz",
                "columns": ["audio", "text"],
                "source_dataset": "Ashegh-Sad-Warrior/Persian_Common_Voice_17_0",
                "processing": "Vosk API batch confirmation",
            }

            # Save the dataset info to the output directory
            info_path = os.path.join(output_dir, "dataset_info.json")
            with open(info_path, "w", encoding="utf-8") as f:
                json.dump(dataset_info, f, indent=2, ensure_ascii=False)

            # Upload the entire folder in one call
            api.upload_folder(
                folder_path=output_dir,
                repo_id=HF_DATASET_NAME,
                repo_type="dataset",
            )

            print(f"🎉 Dataset successfully pushed to: https://huggingface.co/datasets/{HF_DATASET_NAME}")

    except Exception as e:
        print(f"❌ Failed to push to Hugging Face: {e}")
        print("💡 Make sure you have the HF_TOKEN environment variable set")
        print("💡 Set it with: export HF_TOKEN=your_token_here")

else:
    print("❌ No confirmed samples to save")