Advertisement
Not a member of Pastebin yet?
Sign up;
it unlocks many cool features!
- int main(int argc, char **argv) {
- const int terms = 8;
- Var x("x"), y("y"), i("i");
- Param<int> width;
- Param<int> height;
- Param<float> mu;
- Param<float> beta;
- Param<float> sigma;
- Param<float> gamma;
- Param<float> tau;
- Param<int> order, samples;
- Param<double> learning_rate;
- RDom s(0, width, 0, height);
- s.where(abs(sqrt(s.x * s.x + s.y * s.y) - mu) < 3 * sigma);
- // int n = 5; //number of features
- ImageParam disparities(Float(64), 2);
- ImageParam left_features(Float(64), 3);
- ImageParam right_features(Float(64), 3);
- ImageParam a(Float(64), 1);
- ImageParam b(Float(64), 1);
- ImageParam c(Float(64), 1);
- ImageParam point_tri_map(Int(32), 2);
- Func mean;
- mean(x, y) = a(point_tri_map(x, y)) * x + b(point_tri_map(x, y)) * y + c(point_tri_map(x, y));
- Func diff;
- diff(x, y, i) = left_features(x, y, i) - right_features(x, y, i); // for i = 0 to n
- Func norm;
- RDom n_feat(0, 5);
- norm(x, y) = sum(abs(diff(x, y, n_feat)));
- ImageParam d(Int(32), 2);
- Func energy;
- energy(x, y) = beta * norm(x - d(x, y), y) - log (gamma + exp(- pow((d(x, y) - mean(x, y)), 2)) / (2 * sigma * sigma));
- RDom img(0, width, 0, height);
- Func err;
- // err() = f64(0);
- err() = sum(energy(img.x, img.y));
- auto d_err_d = propagate_adjoints(err);
- Func new_d;
- new_d(x, y) = d(x, y) - learning_rate * d_err_d(d)(x, y);
- Pipeline p({err, new_d});
- p.compile_to_static_library("energy", {disparities, a, b, c, point_tri_map, left_features, right_features, d}, "energy");
- }
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement