Advertisement
Guest User

Untitled

a guest
Aug 23rd, 2019
74
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 11.53 KB | None | 0 0
  1. package ride
  2.  
  3. import (
  4. "context"
  5. "database/sql"
  6. "errors"
  7. "fmt"
  8. "github.com/huandu/go-sqlbuilder"
  9. "github.com/lib/pq"
  10. uuid "github.com/satori/go.uuid"
  11. "github.com/yanishoss/germainion/internal/core"
  12. "time"
  13. )
  14.  
  15. var (
  16. ErrRideDoesntExist = errors.New("there is no ride with these information")
  17. )
  18.  
  19. type manager struct {
  20. db *sql.Tx
  21. }
  22.  
  23. // Manager represents a user pool
  24. type Manager interface {
  25. Create(ctx context.Context, driverID core.ID, ride Ride) (core.ID, error)
  26. UpdateByID(ctx context.Context, id core.ID, ride Ride) error
  27. DeleteByID(ctx context.Context, id core.ID) error
  28. DeleteAllByDriverID(ctx context.Context, driverID core.ID) error
  29. GetByID(ctx context.Context, id core.ID) (*Entry, error)
  30. GetAllByDriverID(ctx context.Context, driverID core.ID, limit int, offset int) (Entries, error)
  31. ExistsByID(ctx context.Context, id core.ID) (bool, error)
  32. ExistsByDriverID(ctx context.Context, driverID core.ID) (bool, error)
  33. ExistsByRideIDAndDriverID(ctx context.Context, rideID core.ID, driverID core.ID) (bool, error)
  34. FindBestMatch(ctx context.Context, startLocation, endLocation Location, startTime time.Time, seats, limit, offset int) (Entries, error)
  35. Expire(ctx context.Context) error
  36. }
  37.  
  38. // New creates a new Manager (a singleton)
  39. func New(tx *sql.Tx) Manager {
  40. m := &manager{
  41. db: tx,
  42. }
  43.  
  44. m.init()
  45.  
  46. return m
  47. }
  48.  
  49. func (m manager) init() {
  50. db := m.db
  51.  
  52. _, err := db.Exec(`CREATE EXTENSION IF NOT EXISTS postgis;`)
  53.  
  54. if err != nil {
  55. panic(err)
  56. }
  57.  
  58. // Got to remove the foreign key to the "user_pool" table
  59. _, err = db.Exec(`CREATE TABLE IF NOT EXISTS ride (
  60. id UUID PRIMARY KEY NOT NULL,
  61. driver_id UUID NOT NULL,
  62. seats INT NOT NULL,
  63. start_location GEOGRAPHY NOT NULL,
  64. end_location GEOGRAPHY NOT NULL,
  65. start_time TIMESTAMP WITH TIME ZONE NOT NULL,
  66. expired BOOLEAN NOT NULL DEFAULT false,
  67. created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
  68. updated_at TIMESTAMP WITH TIME ZONE DEFAULT now()
  69. );`)
  70.  
  71. if err != nil {
  72. panic(err)
  73. }
  74. }
  75.  
  76. func (m manager) Create(ctx context.Context, driverID core.ID, ride Ride) (core.ID, error) {
  77. db := m.db
  78.  
  79. id := uuid.NewV4().String()
  80.  
  81. q := `INSERT INTO ride (id, driver_id, seats, start_location, end_location, start_time, expired) VALUES ($1, $2, $3, ST_MakePoint($4, $5)::GEOGRAPHY, ST_MakePoint($6, $7)::GEOGRAPHY, $8, $9)`
  82.  
  83. _, err := db.ExecContext(
  84. ctx,
  85. q,
  86. id,
  87. driverID,
  88. *ride.Seats,
  89. ride.StartLocation[0],
  90. ride.StartLocation[1],
  91. ride.EndLocation[0],
  92. ride.EndLocation[1],
  93. pq.FormatTimestamp(*ride.StartTime),
  94. false)
  95.  
  96. if err != nil {
  97. return "", err
  98. }
  99.  
  100. return id, nil
  101. }
  102.  
  103. func (m manager) UpdateByID(ctx context.Context, id core.ID, ride Ride) error {
  104. db := m.db
  105.  
  106. oldRide, err := m.GetByID(ctx, id)
  107.  
  108. if err != nil {
  109. return err
  110. }
  111.  
  112. ub := sqlbuilder.NewUpdateBuilder()
  113.  
  114. ub.
  115. Update("ride").
  116. Set("updated_at = now()")
  117.  
  118. if ride.Seats != nil && (*ride.Seats != *oldRide.Seats) {
  119. ub.Set("seats = " + ub.Var(*ride.Seats))
  120. }
  121.  
  122. if ride.StartLocation != nil && (ride.StartLocation[0] != 0 || ride.StartLocation[1] != 0) && (ride.StartLocation[0] != oldRide.StartLocation[0] || ride.StartLocation[1] != oldRide.StartLocation[1]) {
  123. ub.Set("start_location = " + ub.Var(*ride.StartLocation))
  124. }
  125.  
  126. if ride.EndLocation != nil && (ride.EndLocation[0] != 0 || ride.EndLocation[1] != 0) && (ride.EndLocation[0] != oldRide.EndLocation[0] || ride.EndLocation[1] != oldRide.EndLocation[1]) {
  127. ub.Set("end_location = " + ub.Var(*ride.StartLocation))
  128. }
  129.  
  130. if ride.StartTime != nil && !ride.StartTime.Equal(*oldRide.StartTime) {
  131. ub.Set("start_time = " + ub.Var(pq.FormatTimestamp(*ride.StartTime)))
  132. }
  133.  
  134. if ride.Expired != nil && (*ride.Expired != *oldRide.Expired) {
  135. ub.Set("expired = " + ub.Var(*ride.Expired))
  136. }
  137.  
  138. ub.Where("id = " + ub.Var(id))
  139.  
  140. q, args := ub.BuildWithFlavor(sqlbuilder.PostgreSQL)
  141.  
  142. _, err = db.ExecContext(ctx, q, args...)
  143.  
  144. if err != nil {
  145. return err
  146. }
  147.  
  148. return nil
  149. }
  150.  
  151. func (m manager) GetByID(ctx context.Context, id core.ID) (*Entry, error) {
  152. db := m.db
  153.  
  154. exists, err := m.ExistsByID(ctx, id)
  155.  
  156. if err != nil {
  157. return nil, err
  158. }
  159.  
  160. if !exists {
  161. return nil, ErrRideDoesntExist
  162. }
  163.  
  164. q := `
  165. SELECT
  166. id,
  167. driver_id,
  168. seats,
  169. r.seats - COALESCE(b.total_seats, 0) AS seats_left,
  170. ST_X(start_location::GEOMETRY),
  171. ST_Y(start_location::GEOMETRY),
  172. ST_X(end_location::GEOMETRY),
  173. ST_Y(end_location::GEOMETRY),
  174. start_time,
  175. expired,
  176. created_at,
  177. updated_at
  178. FROM ride r
  179. LEFT JOIN (
  180. SELECT
  181. b.ride_id,
  182. SUM(b.seats) AS total_seats
  183. FROM booking b
  184. GROUP BY b.ride_id
  185. ) b ON r.id = b.ride_id
  186. WHERE id = $1
  187. GROUP BY r.id, b.ride_id, b.total_seats
  188. `
  189.  
  190. row := db.QueryRowContext(ctx, q, id)
  191.  
  192. entry := NewEntry()
  193.  
  194. if err := row.Scan(
  195. entry.ID,
  196. entry.DriverID,
  197. entry.Seats,
  198. entry.SeatsLeft,
  199. &entry.StartLocation[0],
  200. &entry.StartLocation[1],
  201. &entry.EndLocation[0],
  202. &entry.EndLocation[1],
  203. entry.StartTime,
  204. entry.Expired,
  205. entry.CreatedAt,
  206. entry.UpdatedAt,
  207. ); err != nil {
  208. fmt.Println(err)
  209. return nil, err
  210. }
  211.  
  212. return entry, nil
  213. }
  214.  
  215. func (m manager) ExistsByID(ctx context.Context, id core.ID) (bool, error) {
  216. db := m.db
  217.  
  218. sb := sqlbuilder.NewSelectBuilder()
  219.  
  220. q, args := sb.
  221. Select("COUNT(*)").
  222. From("ride").
  223. Where("id = " + sb.Var(id)).
  224. Limit(1).
  225. BuildWithFlavor(sqlbuilder.PostgreSQL)
  226.  
  227. count := 0
  228.  
  229. row := db.QueryRowContext(ctx, q, args...)
  230.  
  231. if err := row.Scan(&count); err != nil {
  232. return false, err
  233. }
  234.  
  235. return count > 0, nil
  236. }
  237.  
  238. func (m manager) ExistsByDriverID(ctx context.Context, driverID core.ID) (bool, error) {
  239. db := m.db
  240.  
  241. sb := sqlbuilder.NewSelectBuilder()
  242.  
  243. q, args := sb.
  244. Select("COUNT(*)").
  245. From("ride").
  246. Where("driver_id = " + sb.Var(driverID)).
  247. Limit(1).
  248. BuildWithFlavor(sqlbuilder.PostgreSQL)
  249.  
  250. count := 0
  251.  
  252. row := db.QueryRowContext(ctx, q, args...)
  253.  
  254. if err := row.Scan(&count); err != nil {
  255. return false, err
  256. }
  257.  
  258. return count > 0, nil
  259. }
  260.  
  261. func (m manager) DeleteByID(ctx context.Context, id core.ID) error {
  262. db := m.db
  263.  
  264. exists, err := m.ExistsByID(ctx, id)
  265.  
  266. if err != nil {
  267. return err
  268. }
  269.  
  270. if !exists {
  271. return ErrRideDoesntExist
  272. }
  273.  
  274. delb := sqlbuilder.NewDeleteBuilder()
  275.  
  276. q, args := delb.
  277. DeleteFrom("ride").
  278. Where("id = " + delb.Var(id)).
  279. BuildWithFlavor(sqlbuilder.PostgreSQL)
  280.  
  281. _, err = db.ExecContext(ctx, q, args...)
  282.  
  283. if err != nil {
  284. return err
  285. }
  286.  
  287. return nil
  288. }
  289.  
  290. func (m manager) GetAllByDriverID(ctx context.Context, driverID core.ID, limit, offset int) (Entries, error) {
  291. db := m.db
  292.  
  293. exists, err := m.ExistsByDriverID(ctx, driverID)
  294.  
  295. if err != nil {
  296. return nil, err
  297. }
  298.  
  299. if !exists {
  300. return nil, ErrRideDoesntExist
  301. }
  302.  
  303. q := `
  304. SELECT
  305. id,
  306. driver_id,
  307. seats,
  308. r.seats - COALESCE(b.total_seats, 0) AS seats_left,
  309. ST_X(start_location::GEOMETRY),
  310. ST_Y(start_location::GEOMETRY),
  311. ST_X(end_location::GEOMETRY),
  312. ST_Y(end_location::GEOMETRY),
  313. start_time,
  314. expired,
  315. created_at,
  316. updated_at
  317. FROM ride r
  318. LEFT JOIN (
  319. SELECT
  320. b.ride_id,
  321. SUM(b.seats) AS total_seats
  322. FROM booking b
  323. GROUP BY b.ride_id
  324. ) b ON r.id = b.ride_id
  325. WHERE driver_id = $1
  326. GROUP BY r.id, b.ride_id, b.total_seats
  327. ORDER BY created_at DESC
  328. `
  329.  
  330. if limit == core.DefaultLimit {
  331. q += "\nLIMIT ALL"
  332. } else {
  333. q += fmt.Sprintf("\nLIMIT %d", limit)
  334. }
  335.  
  336. if offset == core.DefaultOffset {
  337. q += "\nOFFSET 0"
  338. } else {
  339. q += fmt.Sprintf("\nOFFSET %d", offset)
  340. }
  341.  
  342. rows, err := db.QueryContext(ctx, q, driverID)
  343.  
  344. if err != nil {
  345. return nil, err
  346. }
  347.  
  348. defer rows.Close()
  349.  
  350. entries := make(Entries, 0, core.DefaultScanAllocation)
  351.  
  352. for rows.Next() {
  353. entry := NewEntry()
  354.  
  355. if err := rows.Scan(
  356. entry.ID,
  357. entry.DriverID,
  358. entry.Seats,
  359. entry.SeatsLeft,
  360. &entry.StartLocation[0],
  361. &entry.StartLocation[1],
  362. &entry.EndLocation[0],
  363. &entry.EndLocation[1],
  364. entry.StartTime,
  365. entry.Expired,
  366. entry.CreatedAt,
  367. entry.UpdatedAt,
  368. ); err != nil {
  369. return nil, err
  370. }
  371.  
  372. entries = append(entries, entry)
  373. }
  374.  
  375. return entries, nil
  376. }
  377.  
  378. func (m manager) FindBestMatch(ctx context.Context, startLocation, endLocation Location, startTime time.Time, seats, limit, offset int) (Entries, error) {
  379. db := m.db
  380.  
  381. q := `
  382. SELECT
  383. id,
  384. driver_id,
  385. seats,
  386. seats_left,
  387. ST_X(start_location::GEOMETRY),
  388. ST_Y(start_location::GEOMETRY),
  389. ST_X(end_location::GEOMETRY),
  390. ST_Y(end_location::GEOMETRY),
  391. start_time,
  392. expired,
  393. created_at,
  394. updated_at
  395. FROM (
  396. SELECT
  397. *,
  398. ST_DistanceSphere(r.start_location::GEOMETRY, ST_MakePoint($1, $2)) AS start_dist,
  399. ST_DistanceSphere(r.end_location::GEOMETRY, ST_MakePoint($3, $4)) AS end_dist,
  400. r.seats - COALESCE(b.total_seats, 0) AS seats_left
  401. FROM ride r
  402. LEFT JOIN (
  403. SELECT
  404. b.ride_id,
  405. SUM(b.seats) AS total_seats
  406. FROM booking b
  407. GROUP BY b.ride_id
  408. ) b ON r.id = b.ride_id
  409. WHERE r.start_time >= $5::TIMESTAMPTZ
  410. AND r.start_time <= $5::TIMESTAMPTZ + '20 minute'::INTERVAL
  411. AND NOT r.expired
  412. GROUP BY r.id, b.ride_id, b.total_seats
  413. ) AS q
  414. WHERE (start_dist + end_dist) <= 1000
  415. AND seats_left >= $6
  416. ORDER BY
  417. start_time ASC,
  418. (start_dist + end_dist) ASC,
  419. seats_left DESC
  420. `
  421.  
  422. // Awkwardly try to handle null limit and offset
  423. if limit == core.DefaultLimit {
  424. q += "\nLIMIT ALL"
  425. } else {
  426. q += fmt.Sprintf("\nLIMIT %d", limit)
  427. }
  428.  
  429. if offset == core.DefaultOffset {
  430. q += "\nOFFSET 0"
  431. } else {
  432. q += fmt.Sprintf("\nOFFSET %d", offset)
  433. }
  434.  
  435. rows, err := db.QueryContext(
  436. ctx,
  437. q,
  438. startLocation[0],
  439. startLocation[1],
  440. endLocation[0],
  441. endLocation[1],
  442. startTime,
  443. seats,
  444. )
  445.  
  446. if err != nil {
  447. return nil, err
  448. }
  449.  
  450. defer rows.Close()
  451.  
  452. entries := make(Entries, 0, core.DefaultScanAllocation)
  453.  
  454. for rows.Next() {
  455. entry := NewEntry()
  456.  
  457. if err := rows.Scan(
  458. entry.ID,
  459. entry.DriverID,
  460. entry.Seats,
  461. entry.SeatsLeft,
  462. &entry.StartLocation[0],
  463. &entry.StartLocation[1],
  464. &entry.EndLocation[0],
  465. &entry.EndLocation[1],
  466. entry.StartTime,
  467. entry.Expired,
  468. entry.CreatedAt,
  469. entry.UpdatedAt,
  470. ); err != nil {
  471. return nil, err
  472. }
  473.  
  474. entries = append(entries, entry)
  475. }
  476.  
  477. return entries, nil
  478. }
  479.  
  480. func (m manager) DeleteAllByDriverID(ctx context.Context, driverID core.ID) error {
  481. db := m.db
  482.  
  483. exists, err := m.ExistsByDriverID(ctx, driverID)
  484.  
  485. if err != nil {
  486. return err
  487. }
  488.  
  489. if !exists {
  490. return ErrRideDoesntExist
  491. }
  492.  
  493. delb := sqlbuilder.NewDeleteBuilder()
  494.  
  495. q, args := delb.
  496. DeleteFrom("ride").
  497. Where("driver_id = " + delb.Var(driverID)).
  498. BuildWithFlavor(sqlbuilder.PostgreSQL)
  499.  
  500. _, err = db.ExecContext(ctx, q, args...)
  501.  
  502. if err != nil {
  503. return err
  504. }
  505.  
  506. return nil
  507. }
  508.  
  509. func (m manager) Expire(ctx context.Context) error {
  510. db := m.db
  511.  
  512. q := `
  513. UPDATE ride
  514. SET expired = true
  515. WHERE
  516. start_time <= now()
  517. AND NOT expired;
  518. `
  519.  
  520. _, err := db.ExecContext(ctx, q)
  521.  
  522. if err != nil {
  523. return err
  524. }
  525.  
  526. return nil
  527. }
  528.  
  529. func (m manager) ExistsByRideIDAndDriverID(ctx context.Context, rideID core.ID, driverID core.ID) (bool, error) {
  530. db := m.db
  531.  
  532. sb := sqlbuilder.NewSelectBuilder()
  533.  
  534. q, args := sb.
  535. Select("COUNT(*)").
  536. From("ride").
  537. Where("driver_id = "+sb.Var(driverID), "id = "+sb.Var(rideID)).
  538. Limit(1).
  539. BuildWithFlavor(sqlbuilder.PostgreSQL)
  540.  
  541. count := 0
  542.  
  543. row := db.QueryRowContext(ctx, q, args...)
  544.  
  545. if err := row.Scan(&count); err != nil {
  546. return false, err
  547. }
  548.  
  549. return count > 0, nil
  550. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement