import gradio as gr
from datasets import load_dataset, Dataset
from difflib import ndiff
import pandas as pd
from gradio_huggingfacehub_search import HuggingfaceHubSearch

from semhash import SemHash
from semhash.datamodels import DeduplicationResult

from model2vec import StaticModel

# Default demo parameters
default_dataset_name = "SetFit/amazon_massive_scenario_en-US"
default_dataset1_split = "train"
default_dataset2_split = "test"
default_text_column = "text"
default_threshold = 0.9

# Load the Model2Vec static embedding model used by SemHash
model = StaticModel.from_pretrained("minishlab/potion-base-8M")

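# NOTE: potion-base-8M is a small static embedding model from minishlab;
# any other Model2Vec-compatible model name should work here as well.
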
def display_word_differences(x: str, y: str) -> str:
    """
    Display the word-level differences between two texts, formatted to avoid
    misinterpretation of Markdown syntax.
    """
    diff = ndiff(x.split(), y.split())
    formatted_diff = "\n".join(word for word in diff if word.startswith(("+", "-")))
    return f"```\n{formatted_diff}\n```"

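# Example (illustrative): only changed words are kept, prefixed with -/+ and
# wrapped in a Markdown code fence, e.g.
#   display_word_differences("the colour is red", "the color is red")
#   -> "```\n- colour\n+ color\n```"
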
def load_dataset_texts(
    dataset_name: str, dataset_split: str, text_column: str
) -> tuple[list[str], Dataset]:
    """Load texts from a specified dataset split, returning both the texts and the dataset."""
    ds = load_dataset(dataset_name, split=dataset_split)
    return [example[text_column] for example in ds], ds

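# Example (illustrative, using the defaults defined above):
#   texts, ds = load_dataset_texts(default_dataset_name, "train", "text")
#   # texts is a plain list of strings; ds keeps the full rows for later reuse
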
def deduplicate_single_dataset(
    texts: list[str], threshold: float
) -> DeduplicationResult:
    """
    Deduplicate within a single dataset using SemHash, treating each text
    as a raw string record.
    """
    # Build a SemHash index over the raw text records
    semhash = SemHash.from_records(records=texts, model=model)

    return semhash.self_deduplicate(threshold=threshold)

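# Example (illustrative): records whose embedding similarity exceeds the
# threshold are reported in result.duplicates (with matches and scores),
# while the surviving texts are in result.deduplicated:
#   result = deduplicate_single_dataset(["hi there", "hi there!", "bye"], 0.9)
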
def deduplicate_two_datasets(
    texts1: list[str], texts2: list[str], threshold: float
) -> DeduplicationResult:
    """Deduplicate dataset2 against dataset1, both as raw strings, using SemHash."""
    # Index dataset1, then filter dataset2 records that match it too closely
    semhash = SemHash.from_records(records=texts1, model=model)

    return semhash.deduplicate(records=texts2, threshold=threshold)

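# Example (illustrative; train_texts / test_texts are placeholder names):
# entries of the second list that are semantically close to anything in the
# first list are dropped, e.g. to decontaminate a test split against train:
#   result = deduplicate_two_datasets(train_texts, test_texts, 0.9)
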
def create_deduplicated_dataset(
    original_dataset: Dataset, deduplicated_texts: list[str], text_column: str
) -> Dataset:
    """Create a new dataset with only the deduplicated texts."""
    # Map each text back to its full original row for fast lookup
    text_to_row = {row[text_column]: row for row in original_dataset}

    # Keep only the rows whose text survived deduplication, preserving order
    deduplicated_rows = []
    for text in deduplicated_texts:
        if text in text_to_row:
            deduplicated_rows.append(text_to_row[text])

    return Dataset.from_list(deduplicated_rows)

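# Note: rows are matched back by exact text, so if several rows share the
# same text, only one representative row (the last one seen) is kept.
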
def perform_deduplication(
    deduplication_type: str,
    dataset1_name: str,
    dataset1_split: str,
    dataset1_text_column: str,
    dataset2_name: str = "",
    dataset2_split: str = "",
    dataset2_text_column: str = "",
    threshold: float = default_threshold,
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """
    Perform deduplication on one or two datasets using SemHash, reporting
    progress and results to the Gradio UI.
    """
    try:
        threshold = float(threshold)

        # Load the first dataset
        texts1, dataset1 = load_dataset_texts(
            dataset1_name, dataset1_split, dataset1_text_column
        )

        if deduplication_type == "Single dataset":
            # Deduplicate within dataset1
            result = deduplicate_single_dataset(texts1, threshold=threshold)

            # Sort each record's matches by ascending similarity score
            for duprec in result.duplicates:
                duprec.duplicates.sort(key=lambda x: x[1])

            # Build a dataset containing only the surviving rows
            deduplicated_dataset = create_deduplicated_dataset(
                dataset1, result.deduplicated, dataset1_text_column
            )

            num_duplicates = len(result.duplicates)
            deduplicated_count = len(result.deduplicated)
            total_docs = len(texts1)

            # Collect up to five example duplicate pairs for display
            examples_table = None
            if num_duplicates > 0:
                duplicates_with_data = [
                    duprec for duprec in result.duplicates if duprec.duplicates
                ]

                if duplicates_with_data:
                    table_data = []
                    for duprec in duplicates_with_data[:5]:
                        dup_text = duprec.record
                        orig_text, score = duprec.duplicates[0]
                        table_data.append(
                            [
                                orig_text[:200] + "..."
                                if len(orig_text) > 200
                                else orig_text,
                                dup_text[:200] + "..."
                                if len(dup_text) > 200
                                else dup_text,
                                f"{score:.4f}",
                            ]
                        )

                    examples_table = pd.DataFrame(
                        table_data,
                        columns=["Original Text", "Duplicate Text", "Similarity Score"],
                    )

            gr.Info(
                f"Deduplication completed! Found {num_duplicates} duplicates. "
                f"Dataset reduced from {total_docs} to {deduplicated_count} unique documents."
            )

            # Show the examples table only when there is something to display
            if examples_table is not None and not examples_table.empty:
                return deduplicated_dataset, gr.update(
                    visible=True, value=examples_table
                )
            else:
                return deduplicated_dataset, gr.update(visible=False)

        else:
            # Cross-dataset: load dataset2 and deduplicate it against dataset1
            texts2, dataset2 = load_dataset_texts(
                dataset2_name, dataset2_split, dataset2_text_column
            )

            result = deduplicate_two_datasets(texts1, texts2, threshold=threshold)

            # Sort each record's matches by ascending similarity score
            for duprec in result.duplicates:
                duprec.duplicates.sort(key=lambda x: x[1])

            # Build a dataset containing only the surviving rows of dataset2
            deduplicated_dataset = create_deduplicated_dataset(
                dataset2, result.deduplicated, dataset2_text_column
            )

            num_duplicates = len(result.duplicates)
            total_docs2 = len(texts2)
            deduplicated_count = len(result.deduplicated)

            # Collect up to five example duplicate pairs for display
            examples_table = None
            if num_duplicates > 0:
                duplicates_with_data = [
                    duprec for duprec in result.duplicates if duprec.duplicates
                ]
                if duplicates_with_data:
                    table_data = []
                    for duprec in duplicates_with_data[:5]:
                        dup_text = duprec.record
                        orig_text, score = duprec.duplicates[0]
                        table_data.append(
                            [
                                orig_text[:200] + "..."
                                if len(orig_text) > 200
                                else orig_text,
                                dup_text[:200] + "..."
                                if len(dup_text) > 200
                                else dup_text,
                                f"{score:.4f}",
                            ]
                        )

                    examples_table = pd.DataFrame(
                        table_data,
                        columns=[
                            "Original Text (Dataset 1)",
                            "Duplicate Text (Dataset 2)",
                            "Similarity Score",
                        ],
                    )

            gr.Info(
                f"Deduplication completed! Found {num_duplicates} duplicates in Dataset 2. "
                f"Dataset reduced from {total_docs2} to {deduplicated_count} unique documents."
            )

            # Show the examples table only when there is something to display
            if examples_table is not None and not examples_table.empty:
                return deduplicated_dataset, gr.update(
                    visible=True, value=examples_table
                )
            else:
                return deduplicated_dataset, gr.update(visible=False)

    except Exception as e:
        # gr.Error must be raised (not just constructed) to surface in the UI
        raise gr.Error(f"An error occurred during deduplication: {str(e)}")

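# perform_deduplication returns (deduplicated_dataset, examples_table_update)
# and is wired to the "Deduplicate" button below. Illustrative direct call
# (the gr.Info/gr.Error notifications only render inside a running app):
#   ds, table_update = perform_deduplication(
#       "Single dataset", default_dataset_name, "train", "text"
#   )
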
def push_to_hub(
    deduplicated_dataset: Dataset,
    output_dataset_name: str,
    oauth_profile: gr.OAuthProfile | None,
    oauth_token: gr.OAuthToken | None,
    progress: gr.Progress = gr.Progress(),
) -> str:
    """Push the deduplicated dataset to the Hugging Face Hub."""
    if oauth_token is None:
        raise gr.Error("Please log in with Hugging Face to push datasets to the Hub.")

    if not output_dataset_name.strip():
        raise gr.Error("Please provide a dataset name.")

    if deduplicated_dataset is None:
        raise gr.Error(
            "No deduplicated dataset available. Please run deduplication first."
        )

    try:
        progress(0.1, desc="Preparing dataset...")

        # Prefix the dataset name with the username unless one is already given
        username = oauth_profile.username if oauth_profile else None
        if "/" not in output_dataset_name and username:
            full_dataset_name = f"{username}/{output_dataset_name}"
        else:
            full_dataset_name = output_dataset_name

        progress(0.3, desc="Pushing to Hub...")

        # Upload as a public dataset under the user's account
        deduplicated_dataset.push_to_hub(
            full_dataset_name, token=oauth_token.token, private=False
        )

        progress(1.0, desc="Complete!")

        gr.Info(
            f"Successfully pushed deduplicated dataset with {len(deduplicated_dataset)} rows to the Hub!"
        )

        return (
            f"✅ **Dataset published:** [{full_dataset_name}]"
            f"(https://huggingface.co/datasets/{full_dataset_name})"
        )

    except Exception as e:
        raise gr.Error(f"Failed to push dataset to Hub: {str(e)}")

def get_user_info(oauth_profile: gr.OAuthProfile | None) -> str:
    """Display user login status."""
    if oauth_profile is None:
        return "Not logged in. Please log in to push datasets to the Hub."
    return f"Logged in as: **{oauth_profile.username}**"


def update_push_button_state(oauth_profile: gr.OAuthProfile | None):
    """Update the push button state based on login status."""
    is_logged_in = oauth_profile is not None
    return gr.update(interactive=is_logged_in)

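# Gradio injects the gr.OAuthProfile / gr.OAuthToken arguments automatically
# for event handlers based on their type annotations, which is why they are
# not listed as explicit inputs when the buttons are wired up below.
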
# Build the Gradio UI
with gr.Blocks(
    theme=gr.themes.Ocean(), css="#status_output { height: 50px; overflow: auto; }"
) as demo:
    gr.Markdown("# SemDedup-My-Dataset: Semantic Text Deduplication Using SemHash")
    gr.Markdown("""
    This demo showcases **semantic deduplication** of HuggingFace datasets using [SemHash](https://github.com/MinishLab/semhash) with a [Model2Vec](https://github.com/MinishLab/model2vec) encoder.
    It can be used to identify duplicate texts within a **single dataset** or across **two datasets**.
    You can adjust the similarity threshold to control the strictness of the deduplication.

    """)

    deduplication_type = gr.Radio(
        choices=["Cross-dataset", "Single dataset"],
        label="Deduplication Type",
        value="Cross-dataset",
    )

    with gr.Row():
        dataset1_name = HuggingfaceHubSearch(
            label="Dataset 1 Name",
            placeholder="Search for datasets on HuggingFace Hub",
            search_type="dataset",
            value=default_dataset_name,
        )
        dataset1_split = gr.Textbox(
            value=default_dataset1_split, label="Dataset 1 Split"
        )
        dataset1_text_column = gr.Textbox(
            value=default_text_column, label="Text Column Name"
        )

    # Dataset 2 inputs are only shown for cross-dataset deduplication
    dataset2_inputs = gr.Column(visible=True)
    with dataset2_inputs:
        with gr.Row():
            dataset2_name = HuggingfaceHubSearch(
                label="Dataset 2 Name",
                placeholder="Search for datasets on HuggingFace Hub",
                search_type="dataset",
                value=default_dataset_name,
            )
            dataset2_split = gr.Textbox(
                value=default_dataset2_split, label="Dataset 2 Split"
            )
            dataset2_text_column = gr.Textbox(
                value=default_text_column, label="Text Column Name"
            )

    threshold = gr.Slider(
        0.0, 1.0, value=default_threshold, label="Similarity Threshold"
    )

    with gr.Row():
        compute_button = gr.Button("Deduplicate", variant="primary")

    status_output = gr.Markdown(elem_id="status_output")

    # Table showing up to five example duplicate pairs
    examples_table = gr.Dataframe(
        headers=["Original Text", "Duplicate Text", "Similarity Score"],
        datatype=["str", "str", "str"],
    )

    # Holds the deduplicated dataset between the two button clicks
    deduplicated_dataset_state = gr.State()

    # Controls for publishing the result to the Hub
    gr.Markdown("## Push Deduplicated Dataset to Hub")
    with gr.Row():
        with gr.Column():
            output_dataset_name = gr.Textbox(
                label="Output Dataset Name",
                placeholder="my-deduplicated-dataset",
                info="Will be saved as username/dataset-name",
            )
        with gr.Column():
            push_button = gr.Button(
                "Push to Hub", variant="secondary", interactive=False
            )
            login_button = gr.LoginButton()

    # Status messages for login and publishing
    with gr.Row():
        user_info = gr.Markdown()
        push_output = gr.Markdown()

    # Enable OAuth login
    login_button.activate()

    def update_visibility(choice: str):
        return gr.update(visible=(choice == "Cross-dataset"))

    deduplication_type.change(
        update_visibility, inputs=deduplication_type, outputs=dataset2_inputs
    )

    # Refresh the login status and push-button state on load and after login
    demo.load(get_user_info, inputs=None, outputs=user_info)
    demo.load(update_push_button_state, inputs=None, outputs=push_button)
    login_button.click(get_user_info, inputs=None, outputs=user_info)
    login_button.click(update_push_button_state, inputs=None, outputs=push_button)

    compute_button.click(
        fn=perform_deduplication,
        inputs=[
            deduplication_type,
            dataset1_name,
            dataset1_split,
            dataset1_text_column,
            dataset2_name,
            dataset2_split,
            dataset2_text_column,
            threshold,
        ],
        outputs=[deduplicated_dataset_state, examples_table],
    )

    push_button.click(
        fn=push_to_hub,
        inputs=[
            deduplicated_dataset_state,
            output_dataset_name,
        ],
        outputs=push_output,
    )

demo.launch(ssr_mode=False)