src.private_count_sketch.cs_client

  1from sympy import primerange
  2import random
  3import numpy as np
  4import importlib.util
  5import os
  6import argparse
  7import time
  8from progress.bar import Bar
  9from tabulate import tabulate
 10import sys
 11import pandas as pd
 12import pickle
 13import statistics
 14
 15from utils.utils import load_dataset, generate_error_table, generate_hash_functions, generate_hash_function_G, display_results
 16
 17class CSClient:
 18    def __init__(self, k, m, dataset, domain):
 19        self.k = k 
 20        self.m = m
 21        self.dataset = dataset
 22        self.domain = domain
 23        self.N = len(dataset)
 24        
 25        # Creation of the sketch matrix
 26        self.M = np.zeros((self.k, self.m))
 27
 28        # Definition of the hash family 3 by 3
 29        primes = list(primerange(10**6, 10**7))
 30        p = primes[random.randint(0, len(primes)-1)]
 31        self.H = generate_hash_functions(self.k,p, 3,self.m)
 32
 33        # Definition of the hash family 4 by 4
 34        prime = 2**31 -1
 35        a = random.randint(1, prime-1)
 36        b = random.randint(0, prime-1)
 37        c = random.randint(1, prime-1)
 38        d = random.randint(0, prime-1)
 39        self.G = generate_hash_function_G(self.k, p, a, b, c, d)
 40
 41    def client(self, d):
 42        j = random.randint(0, self.k-1)
 43        v = np.full(self.m, -1)
 44        selected_hash = self.H[j]
 45        v[selected_hash(d)] = 1
 46        return v, j
 47   
 48    def update_sketch_matrix(self, d):
 49        for i in range (self.k):
 50            selected_hash = self.H[i]
 51            hash_index = selected_hash(d)
 52            self.M[i ,hash_index] += self.G[i](d)
 53
 54    def estimate_client(self, d):
 55        vector_median = []
 56        for i in range(self.k):
 57            selected_hash = self.H[i]
 58            vector_median.append(self.M[i,selected_hash(d)] * self.G[i](d))
 59        median = statistics.median(vector_median)
 60        return median
 61    
 62    def server_simulator(self):
 63        bar = Bar('Processing client data', max=len(self.dataset), suffix='%(percent)d%%')
 64
 65        for d in self.dataset:
 66            self.update_sketch_matrix(d)
 67            bar.next()
 68        bar.finish()
 69
 70        F_estimated = {}
 71        bar = Bar('Obtaining histogram of estimated frequencies', max=len(self.domain), suffix='%(percent)d%%')
 72        for x in self.domain:
 73            F_estimated[x] = self.estimate_client(x)
 74            bar.next()
 75        bar.finish()
 76        return F_estimated
 77
 78def run_cs_client(k, m, d):
 79    # Load the dataset
 80    dataset_name = f"{d}_filtered"
 81    dataset, df, domain = load_dataset(dataset_name)
 82
 83    # Initialize the CMSClient
 84    PCMS = CSClient(k, m, dataset, domain)
 85
 86    # Simulate the server side
 87    f_estimated = PCMS.server_simulator()
 88
 89    # Save f_estimated to a file
 90    df_estimated = pd.DataFrame(list(f_estimated.items()), columns=['Element', 'Frequency'])
 91
 92    script_dir = os.path.dirname(os.path.abspath(__file__))
 93    output_dir = os.path.join(script_dir, "../../data/frequencies")
 94    df_estimated.to_csv(os.path.join(output_dir, f"{d}_freq_estimated_cms.csv"), index=False)
 95
 96    # Show the results
 97    data_table = display_results(df, f_estimated)
 98
 99    return data_table
100
101
102
103
104  
class CSClient:
    """Count Sketch (CS) frequency estimator with a simulated client/server flow.

    The sketch is a ``k x m`` counter matrix ``M``. Row ``i`` pairs a bucket
    hash ``H[i]`` (element -> column index) with a second hash ``G[i]`` whose
    value is added to that bucket on insertion and multiplied back in on
    query. In the Count Sketch construction ``G[i]`` is a +/-1 sign hash;
    both families are built by helpers in ``utils.utils`` (presumably
    returning lists of ``k`` callables — confirm there).

    Attributes:
        k: Number of rows (hash pairs).
        m: Number of columns (buckets) per row.
        dataset: Elements inserted by ``server_simulator``.
        domain: Elements whose frequencies ``server_simulator`` estimates.
        N: ``len(dataset)``.
        M: ``(k, m)`` float matrix of sketch counters.
        H: Bucket hash functions, indexed by row.
        G: Sign hash functions, indexed by row.
    """

    def __init__(self, k, m, dataset, domain):
        """Create an empty sketch and draw fresh random hash families.

        Args:
            k: Number of rows; must be >= 1.
            m: Number of columns per row; must be >= 1.
            dataset: Elements to insert (its ``len`` is stored as ``N``).
            domain: Candidate elements to estimate.

        Raises:
            ValueError: If ``k`` or ``m`` is smaller than 1.
        """
        # Fail early: k == 0 would otherwise surface later as an empty
        # median or a random.randint error, and m == 0 as a modulo by zero.
        if k < 1 or m < 1:
            raise ValueError(f"k and m must be positive integers, got k={k}, m={m}")

        self.k = k
        self.m = m
        self.dataset = dataset
        self.domain = domain
        self.N = len(dataset)

        # Sketch counters: one row per hash pair, one column per bucket.
        self.M = np.zeros((self.k, self.m))

        # Bucket hashes H ("3 by 3" family: the literal 3 is presumably the
        # independence degree), over a random prime drawn from [10**6, 10**7).
        primes = list(primerange(10**6, 10**7))
        p = primes[random.randint(0, len(primes) - 1)]
        self.H = generate_hash_functions(self.k, p, 3, self.m)

        # Sign hashes G ("4 by 4" family) modulo the Mersenne prime 2**31 - 1.
        # The coefficients are drawn below `prime`, so `prime` must also be the
        # modulus; the smaller `p` used for H would make them exceed it.
        prime = 2**31 - 1
        a = random.randint(1, prime - 1)
        b = random.randint(0, prime - 1)
        c = random.randint(1, prime - 1)
        d = random.randint(0, prime - 1)
        # NOTE(review): one (a, b, c, d) tuple is shared by all k rows; confirm
        # generate_hash_function_G derives independent functions per row.
        self.G = generate_hash_function_G(self.k, prime, a, b, c, d)

    def client(self, d):
        """Encode ``d`` as a +/-1 vector under one uniformly random row.

        Not called by ``server_simulator``.

        Returns:
            ``(v, j)``: ``j`` is a random row index in ``[0, k)``; ``v`` is an
            integer array of length ``m`` holding -1 everywhere except +1 at
            column ``H[j](d)``.
        """
        j = random.randint(0, self.k - 1)
        v = np.full(self.m, -1)
        v[self.H[j](d)] = 1
        return v, j

    def update_sketch_matrix(self, d):
        """Insert one occurrence of ``d``: add ``G[i](d)`` at ``M[i, H[i](d)]`` for every row."""
        for i in range(self.k):
            self.M[i, self.H[i](d)] += self.G[i](d)

    def estimate_client(self, d):
        """Return the estimated frequency of ``d``.

        The estimate is the median over rows of ``M[i, H[i](d)] * G[i](d)``.
        ``statistics.median`` averages the two middle values when ``k`` is even.
        """
        per_row = [self.M[i, self.H[i](d)] * self.G[i](d) for i in range(self.k)]
        return statistics.median(per_row)

    def server_simulator(self):
        """Insert all of ``dataset`` into the sketch, then estimate every ``domain`` element.

        Returns:
            dict mapping each element of ``domain`` to its estimated frequency.
        """
        bar = Bar('Processing client data', max=len(self.dataset), suffix='%(percent)d%%')
        try:
            for d in self.dataset:
                self.update_sketch_matrix(d)
                bar.next()
        finally:
            # progress bars hide the terminal cursor until finish() runs, so
            # always finish them, even if a hash function raises mid-loop.
            bar.finish()

        F_estimated = {}
        bar = Bar('Obtaining histogram of estimated frequencies', max=len(self.domain), suffix='%(percent)d%%')
        try:
            for x in self.domain:
                F_estimated[x] = self.estimate_client(x)
                bar.next()
        finally:
            bar.finish()
        return F_estimated
CSClient(k, m, dataset, domain)
19    def __init__(self, k, m, dataset, domain):
20        self.k = k 
21        self.m = m
22        self.dataset = dataset
23        self.domain = domain
24        self.N = len(dataset)
25        
26        # Creation of the sketch matrix
27        self.M = np.zeros((self.k, self.m))
28
29        # Definition of the hash family 3 by 3
30        primes = list(primerange(10**6, 10**7))
31        p = primes[random.randint(0, len(primes)-1)]
32        self.H = generate_hash_functions(self.k,p, 3,self.m)
33
34        # Definition of the hash family 4 by 4
35        prime = 2**31 -1
36        a = random.randint(1, prime-1)
37        b = random.randint(0, prime-1)
38        c = random.randint(1, prime-1)
39        d = random.randint(0, prime-1)
40        self.G = generate_hash_function_G(self.k, p, a, b, c, d)
k
m
dataset
domain
N
M
H
G
def client(self, d):
42    def client(self, d):
43        j = random.randint(0, self.k-1)
44        v = np.full(self.m, -1)
45        selected_hash = self.H[j]
46        v[selected_hash(d)] = 1
47        return v, j
def update_sketch_matrix(self, d):
49    def update_sketch_matrix(self, d):
50        for i in range (self.k):
51            selected_hash = self.H[i]
52            hash_index = selected_hash(d)
53            self.M[i ,hash_index] += self.G[i](d)
def estimate_client(self, d):
55    def estimate_client(self, d):
56        vector_median = []
57        for i in range(self.k):
58            selected_hash = self.H[i]
59            vector_median.append(self.M[i,selected_hash(d)] * self.G[i](d))
60        median = statistics.median(vector_median)
61        return median
def server_simulator(self):
63    def server_simulator(self):
64        bar = Bar('Processing client data', max=len(self.dataset), suffix='%(percent)d%%')
65
66        for d in self.dataset:
67            self.update_sketch_matrix(d)
68            bar.next()
69        bar.finish()
70
71        F_estimated = {}
72        bar = Bar('Obtaining histogram of estimated frequencies', max=len(self.domain), suffix='%(percent)d%%')
73        for x in self.domain:
74            F_estimated[x] = self.estimate_client(x)
75            bar.next()
76        bar.finish()
77        return F_estimated
def run_cs_client(k, m, d):
 79def run_cs_client(k, m, d):
 80    # Load the dataset
 81    dataset_name = f"{d}_filtered"
 82    dataset, df, domain = load_dataset(dataset_name)
 83
 84    # Initialize the CMSClient
 85    PCMS = CSClient(k, m, dataset, domain)
 86
 87    # Simulate the server side
 88    f_estimated = PCMS.server_simulator()
 89
 90    # Save f_estimated to a file
 91    df_estimated = pd.DataFrame(list(f_estimated.items()), columns=['Element', 'Frequency'])
 92
 93    script_dir = os.path.dirname(os.path.abspath(__file__))
 94    output_dir = os.path.join(script_dir, "../../data/frequencies")
 95    df_estimated.to_csv(os.path.join(output_dir, f"{d}_freq_estimated_cms.csv"), index=False)
 96
 97    # Show the results
 98    data_table = display_results(df, f_estimated)
 99
100    return data_table