Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
%%%% BUILD DATASET %%%%
% Synthetic mixture-of-experts data: scalar inputs plus a constant bias column,
% a softmax gate with fixed weights T that picks one of nb_topics linear experts
% per sample, and isotropic Gaussian noise (variance 0.05) on the expert output.
nb_samples = 1000;
input_dim = 1;
output_dim = 2;
nb_topics = 3;
nb_iter = 25;

% Inputs drawn uniformly on (-5, 5); the last column is overwritten with ones (bias).
x = -5 + 10*rand(nb_samples, input_dim+1);
x(:, input_dim+1) = 1;

% Ground-truth gating weights, one row per topic: [input coefficient, bias].
%T = rand(nb_topics, input_dim+1);
T = [-2, -1;
      0,  3;
      2,  1];
t = softmax(T*x')';   % gate probabilities, one row per sample

% Expert weights: rows = [input; bias], columns = output dims, pages = topics.
% The random draw is fully overwritten below but is kept so that the RNG stream
% consumed by the later sampling steps is unchanged.
experts = -1 + 2*rand(input_dim + 1, output_dim, nb_topics);
experts(1, 1, :) = 1;   % output 1 reproduces the input ...
experts(2, 1, :) = 0;   % ... with zero offset
% Output 2 of topic k is the line slope*input + intercept; column k = [slope; intercept].
experts(:, 2, 1:3) = reshape([-1, 1, -1; -4, 0, 2], [2, 1, 3]);

% Draw one topic per sample (one-hot rows), then that topic's noisy expert output.
y = mnrnd(1, t);
[~, topic] = max(y, [], 2);
data = zeros(nb_samples, output_dim);
for n = 1:nb_samples
    data(n, :) = x(n, :)*experts(:, :, topic(n)) + mvnrnd(zeros(1, output_dim), 0.05*eye(output_dim));
end

% Keep the generating parameters for comparison with the fitted ones.
real_T = T;
real_expert = experts;
%%%% TRAIN WITH EM %%%%
% Fits the mixture of linear experts
%   p(y | x) = sum_j g_j(x) * N(y; x*experts(:,:,j), noise_var*I),  g(x) = softmax(T*x')
% to (x, data) with generalized EM:
%   - the expert M-step is exact (weighted least squares);
%   - the gating M-step takes gradient-ascent steps whose step size guarantees
%     the expected complete-data log-likelihood does not decrease.
% Consequently ll should be non-decreasing across iterations.
% Reads x, data, T, experts, nb_samples, input_dim, output_dim, nb_topics, nb_iter.
noise_var = 0.05;        % expert noise variance (the value used to generate data)
init_from_truth = false; % true: start from the generating T/experts (E-step sanity check)
gating_steps = 10;       % gradient steps on T per M-step
% 1/L step size: the Hessian of the sample-averaged gating objective is bounded
% by 0.5*mean(||x_i||^2), so a step of 1/L can never decrease that objective.
gating_step = 2 / mean(sum(x.^2, 2));

%Initialization
if ~init_from_truth
    T = rand(nb_topics, input_dim+1);
    experts = -1+2*rand(input_dim + 1, output_dim, nb_topics);
end
ll = zeros(1, nb_iter);
log_norm = -0.5*output_dim*log(2*pi*noise_var);   % log normaliser of N(., noise_var*I)

%Training
for iter = 1:nb_iter
    %%%% EXPECTATION %%%%
    % Work in the log domain: points far from every expert would otherwise make
    % all densities underflow to 0 and the responsibilities become 0/0.
    gating_values = softmax(T*x')';
    log_pred = zeros(nb_samples, nb_topics);
    for j = 1:nb_topics
        resid = data - x*experts(:,:,j);
        log_pred(:,j) = log_norm - 0.5*sum(resid.^2, 2)/noise_var;
    end
    prediction_experts = exp(log_pred);   % expert densities (kept for plotting)
    log_joint = log(gating_values) + log_pred;
    row_max = max(log_joint, [], 2);
    log_evidence = row_max + log(sum(exp(bsxfun(@minus, log_joint, row_max)), 2));
    % Responsibilities; each row sums to 1.
    expectation = exp(bsxfun(@minus, log_joint, log_evidence));

    %%%% LOG-LIKELIHOOD %%%%  (of the parameters used in this E-step)
    ll(iter) = sum(log_evidence);

    %%%% MAXIMIZATION %%%%
    % Experts: weighted least squares with the responsibilities as weights.
    for j = 1:nb_topics
        if sum(expectation(:,j)) > eps   % an expert owning no points keeps its weights
            experts(:,:,j) = lscov(x, data, expectation(:,j));
        end
    end
    % Gate: softmax regression on the soft targets `expectation`. The gradient of
    % mean_i sum_j h_ij*log g_ij with respect to T is (H - G)'*x / nb_samples.
    for k = 1:gating_steps
        g = softmax(T*x')';
        T = T + gating_step * (expectation - g)' * x / nb_samples;
    end
end
%%%% PLOTTING %%%%
% Figure 1: model diagnostics. Figure 2: log-likelihood per EM iteration.
% Each figure is selected explicitly before clf(). A bare clf() clears whatever
% figure is current, which on a re-run is figure 2, so the diagnostics would be
% drawn there and the ll plot would then overwrite subplot 4.
colors = 'brgcmyk';   % one colour per topic, cycling if nb_topics > 7
topic_color = @(k) colors(mod(k - 1, numel(colors)) + 1);
inputs = x(:,1);
tt = softmax(T*x')';  % gate probabilities under the final T

figure(1); clf();

% Top left: data (black) and each expert's noiseless output x*experts(:,:,k).
subplot(2,2,1); hold on;
plot(data(:,1), data(:,2), 'kx');
for k = 1:nb_topics
    d = x*experts(:,:,k);
    plot(d(:,1), d(:,2), [topic_color(k) 'x']);
end
hold off;

% Top right: generating gate t (x markers) vs gate under the final T (o markers).
subplot(2,2,2); hold on;
for k = 1:nb_topics
    plot(inputs, t(:,k), [topic_color(k) 'x'], inputs, tt(:,k), [topic_color(k) 'o']);
end
hold off;

% Bottom left: expert densities (x) and gate values (o) from the last E-step.
subplot(2,2,3); hold on;
for k = 1:nb_topics
    plot(inputs, prediction_experts(:,k), [topic_color(k) 'x'], ...
         inputs, gating_values(:,k), [topic_color(k) 'o']);
end
hold off;

% Bottom right: responsibilities from the last E-step.
subplot(2,2,4); hold on;
for k = 1:nb_topics
    plot(inputs, expectation(:,k), [topic_color(k) 'x']);
end
hold off;

figure(2); clf();
plot(1:numel(ll), ll);
xlabel('EM iteration');
ylabel('log-likelihood');
Add Comment
Please, Sign In to add comment