# %%
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.optimize

rng = np.random.default_rng(1)


# %%
def pairs_in(pairs, S_pairs):
    # Return a boolean mask over `pairs` marking rows that appear, in the same order, as a row of `S_pairs`.
    query_pairs = pairs.copy()
    query_pairs = np.expand_dims(query_pairs, axis=1)  # shape becomes (k, 1, 2) for efficient broadcasting, where k=pairs.shape[0]
    matches = np.all(query_pairs == S_pairs, axis=-1)  # shape is (k, n), where n=S_pairs.shape[0]
    hits = np.any(matches, axis=-1)  # shape is (k,)
    return hits


def draw_unique_pairs(S, k, max_tries=1e3):
    # Draw k pairs from S, unique up to ordering within a pair, by repeatedly permuting S
    # and pairing adjacent elements. Also returns whether the attempt budget was exhausted.
    assert S.shape[0] % 2 == 0
    out = np.zeros((k, 2), dtype=S.dtype)
    pairs_drawn = 0
    tries = 0
    while pairs_drawn != k:
        # Generate candidate pairs by permuting S, then pairing adjacent elements.
        S_perm = rng.permutation(S)
        S_perm_pairs = S_perm.reshape(S_perm.shape[0] // 2, 2)
        tries += 1
        # Base case: accept all candidate pairs (up to k).
        if pairs_drawn == 0:
            pairs_drawn = min(k, S_perm_pairs.shape[0])
            out[0:pairs_drawn] = S_perm_pairs[0:pairs_drawn]
        # Otherwise: accept only those candidate pairs that haven't yet been added to `out`.
        else:
            S_perm = rng.permutation(S)
            S_perm_pairs = S_perm.reshape(S_perm.shape[0] // 2, 2)
            # Compare both orientations of each candidate against the pairs already drawn.
            S_perm_pairs_full = np.concatenate([S_perm_pairs, np.flip(S_perm_pairs, axis=-1)])
            hits = pairs_in(S_perm_pairs_full, out[0:pairs_drawn])
            hits = hits[:hits.shape[0] // 2] | hits[hits.shape[0] // 2:]  # reduce hits to account for "reversed" pairs
            keeps = S_perm_pairs[~hits]
            if keeps.shape[0] > 0:
                pairs_to_add = min(k - pairs_drawn, keeps.shape[0])
                out[pairs_drawn:pairs_drawn + pairs_to_add] = keeps[:pairs_to_add]
                pairs_drawn += pairs_to_add
        if tries >= max_tries:
            print(f'[WARN] N={S.shape[0]} k={k} \tOnly found {pairs_drawn} pairs after {tries} attempts.')
            break
    return out, tries >= max_tries


def experiment(S, k):
    # Draw k unique pairs, then count how many of them coincide (in either orientation)
    # with the reference pairing formed by adjacent elements of S.
    unique_pairs, timeout = draw_unique_pairs(S, k)
    unique_pairs_full = np.concatenate([unique_pairs, np.flip(unique_pairs, axis=-1)])
    hits = pairs_in(unique_pairs_full, S.reshape(S.shape[0] // 2, 2))
    return np.sum(hits), timeout
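# %%
# Quick illustration of pairs_in (an added sanity-check cell, not part of the original
# analysis; `demo_pairs` and `demo_reference` are made-up inputs). Matching is
# order-sensitive row equality, which is why the code above always concatenates a pair
# list with its np.flip-ed copy before querying.
demo_pairs = np.array([[1, 2], [2, 1], [3, 4]])
demo_reference = np.array([[1, 2], [5, 6]])
print(pairs_in(demo_pairs, demo_reference))  # expected: [ True False False]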
# %%
%%time
Ns = np.arange(2, 51, step=2)
max_k = 2000
iters_per_param = 20
results = []
for N in Ns:
    S = np.arange(1, N + 1)
    timeout = False
    for k in range(1, max_k):
        for _ in range(iters_per_param):
            n_hits, timeout = experiment(S, k)
            if timeout:
                break
            results.append({'N': N, 'k': k, 'duplicates': n_hits})
        if timeout:
            # Once k unique pairs can no longer be found for this N, stop increasing k.
            break

# %%
results_df = pd.DataFrame(results)
results_df.to_csv('sim.csv', index=False)

# %%
results_df = pd.read_csv('sim.csv')
# Reparameterize from "k" to "rounds of a full lunch tag"
results_df = results_df.query('k % N == 0')
results_df['round'] = results_df['k'] // results_df['N']
# Focus on whether any duplicates at all were encountered
results_df['has_duplicate'] = results_df['duplicates'] != 0
# Aggregate mean and s.e.m. across iterations
results_summary = results_df.groupby(['N', 'k', 'round'])['has_duplicate'].agg(['mean', 'sem'])
results_summary

# %%
N = results_summary.index.get_level_values('N')
k = results_summary.index.get_level_values('k')
rounds = results_summary.index.get_level_values('round')
mean = results_summary['mean'].values

# %%
%matplotlib widget
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot_trisurf(k[k < 120], N[k < 120], mean[k < 120], cmap='viridis', edgecolor='none')
ax.scatter3D(k[k < 120], N[k < 120], mean[k < 120], c=mean[k < 120], cmap='viridis', vmin=0, vmax=1, s=5, alpha=1)
ax.set_ylabel('N')
ax.set_xlabel('k')
# ax.set_ylabel('round')
ax.set_zlabel('P(duplicate)')

# %%
# Fit an unconstrained quadratic to the largest feasible k observed for each N.
max_k = results_df.reset_index(drop=True).groupby('N')['k'].agg('max')
X = max_k.index.values
Y = max_k.values
f_p2 = lambda x, a, b, c: a + b * x + c * x**2
popt, pcov = scipy.optimize.curve_fit(f_p2, xdata=X, ydata=Y)
print(popt)

# %%
# (-X + 0.5X^2) = X(X/2 - 1)
print('Squared error', np.sum((Y - (X * (X / 2 - 1)))**2))
plt.figure()
max_k.plot(ylabel='maximum k', label='data', linewidth=3)
plt.plot(X, X * (X / 2 - 1), label='quadratic fit', linestyle='dashed', linewidth=3)
plt.legend()
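# %%
# Added closing check (not in the original notebook): X * (X/2 - 1) expands to
# 0.5*X**2 - X, i.e. (a, b, c) = (0, -1, 0.5) in the parameterization of f_p2 above,
# so the unconstrained fit popt should land close to these conjectured coefficients.
conjectured = np.array([0.0, -1.0, 0.5])
print('fitted popt :', popt)
print('conjectured :', conjectured)
print('max abs diff:', np.max(np.abs(popt - conjectured)))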