Skip to content

Commit a238a70

Browse files
committed
update license, readme and example.py
1 parent 8df94de commit a238a70

4 files changed

Lines changed: 28 additions & 160 deletions

File tree

.zenodo.json

Lines changed: 0 additions & 66 deletions
This file was deleted.

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ SOFTWARE.
2727
This software is provided for academic and research purposes. If you use this
2828
code in your research, please cite the original paper:
2929

30-
Khosrowshahli, R., Kheiri, F., Bidgoli, A.A., Tizhoosh, H.R., Makrehchi, M., & Rahnamayan, S. (2024). Enhancing Image Retrieval Through Optimal Barcode Representation.
30+
Khosrowshahli, R., Kheiri, F., Bidgoli, A.A., Tizhoosh, H.R., Makrehchi, M., & Rahnamayan, S. (2025). Enhancing Image Retrieval Through Optimal Barcode Representation.
3131

3232
## Data Usage
3333

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ python main.py
8686
#### Using CGA-dHash on TCGA Brain dataset:
8787

8888
```bash
89-
python main.py --dataset tcga_brain_kimianet --method CGA-dHash --k 10 --cga_n_gen 50
89+
python main.py --dataset tcga_brain_kimianet --method CGA-dHash --k 10 --cga_n_gen 100
9090
```
9191

9292
#### Using neural network methods:
@@ -113,7 +113,7 @@ python main.py --dataset cifar10 --method CGA-dHash --download
113113
- `--method`: Barcoding method (default: `CGA-dHash`)
114114
- `--k`: Number of nearest neighbors for evaluation (default: 10)
115115
- `--n_bits`: Number of bits for hash codes (default: 128)
116-
- `--feature_selection`: Enable feature selection
116+
- `--feature_selection`: Enable feature selection (only supported for the TCGA datasets)
117117
- `--download`: Automatically download datasets if they don't exist
118118
- `--cga_n_gen`: Number of generations for CGA (default: 100)
119119
- `--cga_pop_size`: Population size for CGA (default: 100)

example_usage.py renamed to example.py

Lines changed: 25 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
Example usage of Deep Feature Barcoding with Combinatorial Genetic Algorithm
33
44
This script demonstrates how to use the various barcoding methods
5-
with synthetic data for testing and development purposes.
5+
with the real Fashion-MNIST dataset for testing and evaluation.
66
"""
77

88
import numpy as np
@@ -12,75 +12,10 @@
1212
# Add src to path to import modules
1313
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
1414

15-
from src.utils import evaluate_retrieval, setup_seed
15+
from src.utils import evaluate_retrieval, setup_seed, load_dataset
1616
from src.methods import CGA, AHash, DHash, MinMax, DFT, ITQ, LSH
1717

1818

19-
def generate_synthetic_data(n_samples=1000, n_features=100, n_classes=5, seed=42):
20-
"""
21-
Generate synthetic dataset for testing barcoding methods.
22-
23-
Args:
24-
n_samples: Number of samples to generate
25-
n_features: Number of features per sample
26-
n_classes: Number of classes
27-
seed: Random seed for reproducibility
28-
29-
Returns:
30-
Tuple of (features, labels)
31-
"""
32-
np.random.seed(seed)
33-
34-
# Generate random features with some class structure
35-
features = []
36-
labels = []
37-
38-
for class_id in range(n_classes):
39-
# Generate class-specific mean
40-
class_mean = np.random.randn(n_features) * 2
41-
42-
# Generate samples for this class
43-
n_class_samples = n_samples // n_classes
44-
if class_id < n_samples % n_classes:
45-
n_class_samples += 1
46-
47-
class_features = np.random.randn(n_class_samples, n_features) + class_mean
48-
class_labels = np.full(n_class_samples, class_id)
49-
50-
features.append(class_features)
51-
labels.append(class_labels)
52-
53-
features = np.vstack(features)
54-
labels = np.hstack(labels)
55-
56-
# Shuffle the data
57-
shuffle_idx = np.random.permutation(len(features))
58-
features = features[shuffle_idx]
59-
labels = labels[shuffle_idx]
60-
61-
return features, labels
62-
63-
64-
def split_data(features, labels, train_ratio=0.6, val_ratio=0.2, test_ratio=0.2):
65-
"""Split data into train, validation, and test sets."""
66-
n_samples = len(features)
67-
n_train = int(n_samples * train_ratio)
68-
n_val = int(n_samples * val_ratio)
69-
70-
train_features = features[:n_train]
71-
train_labels = labels[:n_train]
72-
73-
val_features = features[n_train:n_train + n_val]
74-
val_labels = labels[n_train:n_train + n_val]
75-
76-
test_features = features[n_train + n_val:]
77-
test_labels = labels[n_train + n_val:]
78-
79-
return (train_features, train_labels,
80-
val_features, val_labels,
81-
test_features, test_labels)
82-
83-
8419
def run_barcoding_example(method_name, barcoder, train_features, train_labels,
8520
test_features, test_labels, k=5):
8621
"""
@@ -133,29 +68,27 @@ def main():
13368
# Set random seed for reproducibility
13469
setup_seed(42)
13570

136-
# Generate synthetic data
137-
print("\nGenerating synthetic dataset...")
138-
features, labels = generate_synthetic_data(
139-
n_samples=1000,
140-
n_features=64,
141-
n_classes=5,
142-
seed=42
143-
)
144-
145-
print(f"Generated dataset: {features.shape[0]} samples, {features.shape[1]} features, {len(np.unique(labels))} classes")
146-
147-
# Split data
148-
print("Splitting data into train/val/test sets...")
149-
train_features, train_labels, val_features, val_labels, test_features, test_labels = split_data(
150-
features, labels
151-
)
152-
153-
print(f"Train: {len(train_features)} samples")
154-
print(f"Validation: {len(val_features)} samples")
155-
print(f"Test: {len(test_features)} samples")
71+
# Load Fashion-MNIST dataset
72+
print("\nLoading Fashion-MNIST dataset...")
73+
try:
74+
train_features, train_labels, val_features, val_labels, test_features, test_labels = load_dataset(
75+
dataset_name="fashion",
76+
download=True # Auto-download if not available
77+
)
78+
79+
print(f"Dataset loaded successfully!")
80+
print(f"Train: {len(train_features)} samples, {train_features.shape[1]} features")
81+
print(f"Validation: {len(val_features)} samples")
82+
print(f"Test: {len(test_features)} samples")
83+
print(f"Number of classes: {len(np.unique(train_labels))}")
84+
85+
except Exception as e:
86+
print(f"Error loading Fashion-MNIST dataset: {str(e)}")
87+
print("Please check if the dataset is available or try downloading it manually.")
88+
return
15689

15790
# Initialize barcoding methods
158-
num_features = features.shape[1]
91+
num_features = train_features.shape[1]
15992
num_bits = 32 # Using smaller number of bits for faster computation in example
16093
k = 5 # Number of neighbors for evaluation
16194

@@ -211,7 +144,7 @@ def main():
211144

212145
# Print summary results
213146
print("\n" + "=" * 50)
214-
print("SUMMARY RESULTS")
147+
print("SUMMARY RESULTS - Fashion-MNIST Dataset")
215148
print("=" * 50)
216149
print(f"{'Method':<15} {'F1':<8} {'Prec@{k}':<8} {'mAP':<8}".format(k=k))
217150
print("-" * 50)
@@ -223,8 +156,9 @@ def main():
223156
print(f"{method_name:<15} {'ERROR':<8} {'ERROR':<8} {'ERROR':<8}")
224157

225158
print("\nExample completed successfully!")
226-
print("\nNote: This example uses synthetic data and reduced parameters for speed.")
227-
print("For real experiments, use larger populations, more generations, and real datasets.")
159+
print("\nNote: This example uses Fashion-MNIST dataset with reduced parameters for speed.")
160+
print("For real experiments, use larger populations, more generations, and other datasets.")
161+
print("\nAvailable datasets: fashion, cifar10, cifar100, covid19, and various TCGA medical datasets.")
228162

229163

230164
if __name__ == "__main__":

0 commit comments

Comments
 (0)