Advertisement
t_ashpool

casbin adapter demo

Dec 14th, 2019
628
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Go 10.66 KB | None | 0 0
  1. // Copyright 2017 The casbin Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. //      http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14.  
  15. package gormadapter
  16.  
  import (
      "errors"
      "runtime"
      "strings"

      "github.com/casbin/casbin/v2/model"
      "github.com/casbin/casbin/v2/persist"
      "github.com/jinzhu/gorm"
      "github.com/lib/pq"
  )
  26.  
  var (
      // tablePrefix holds the prefix most recently passed to
      // NewAdapterByDBUsePrefix. Each Adapter keeps its own copy in
      // Adapter.tablePrefix; this package-level value is not read anywhere
      // in this file.
      tablePrefix string
  )
  28.  
  // CasbinRule is the row model for the policy table. Each row holds one rule:
  // the rule's ptype key (e.g. "p", "g") in PType and up to six rule values
  // in V0..V5.
  //
  // TablePrefix is tagged `gorm:"-"`, so it is not stored as a column; it only
  // feeds TableName. Every persisted field is limited to size 100 by its gorm
  // tag. Field order matters for positional composite literals and
  // (presumably) for the column order gorm uses when creating the table.
  type CasbinRule struct {
      TablePrefix                   string `gorm:"-"`
      PType, V0, V1, V2, V3, V4, V5 string `gorm:"size:100"`
  }
  39.  
  // Filter selects the rules loaded by LoadFilteredPolicy. Each non-empty
  // slice becomes an "in (...)" condition on its column (PType -> p_type,
  // V0..V5 -> v0..v5). The conditions are chained with gorm's Where, so they
  // are ANDed together. An empty slice adds no condition.
  type Filter struct {
      PType, V0, V1, V2, V3, V4, V5 []string
  }
  49.  
  50. func (c *CasbinRule) TableName() string {
  51.     return c.TablePrefix + "casbin_rule" //as Gorm keeps table names are plural, and we love consistency
  52. }
  53.  
  54. // Adapter represents the Gorm adapter for policy storage.
  55. type Adapter struct {
  56.     tablePrefix    string
  57.     driverName     string
  58.     dataSourceName string
  59.     dbSpecified    bool
  60.     db             *gorm.DB
  61.     isFiltered     bool
  62. }
  63.  
  64. // finalizer is the destructor for Adapter.
  65. func finalizer(a *Adapter) {
  66.     err := a.db.Close()
  67.     if err != nil {
  68.         panic(err)
  69.     }
  70. }
  71.  
  72. // NewAdapter is the constructor for Adapter.
  73. // dbSpecified is an optional bool parameter. The default value is false.
  74. // It's up to whether you have specified an existing DB in dataSourceName.
  75. // If dbSpecified == true, you need to make sure the DB in dataSourceName exists.
  76. // If dbSpecified == false, the adapter will automatically create a DB named "casbin".
  77. func NewAdapter(driverName string, dataSourceName string, dbSpecified ...bool) (*Adapter, error) {
  78.     a := &Adapter{}
  79.     a.driverName = driverName
  80.     a.dataSourceName = dataSourceName
  81.  
  82.     if len(dbSpecified) == 0 {
  83.         a.dbSpecified = false
  84.     } else if len(dbSpecified) == 1 {
  85.         a.dbSpecified = dbSpecified[0]
  86.     } else {
  87.         return nil, errors.New("invalid parameter: dbSpecified")
  88.     }
  89.  
  90.     // Open the DB, create it if not existed.
  91.     err := a.open()
  92.     if err != nil {
  93.         return nil, err
  94.     }
  95.  
  96.     // Call the destructor when the object is released.
  97.     runtime.SetFinalizer(a, finalizer)
  98.  
  99.     return a, nil
  100. }
  101.  
  102. // NewAdapterByDB obtained through an existing Gorm instance get  a adapter, specify the table prefix
  103. // Example: gormadapter.NewAdapterByDBUsePrefix(&db, "cms_") Automatically generate table name like this "cms_casbin_rule"
  104. func NewAdapterByDBUsePrefix(db *gorm.DB, prefix string) (*Adapter, error) {
  105.     a := &Adapter{
  106.         tablePrefix: prefix,
  107.         db:          db,
  108.     }
  109.  
  110.     tablePrefix = prefix
  111.  
  112.     err := a.createTable()
  113.     if err != nil {
  114.         return nil, err
  115.     }
  116.  
  117.     return a, nil
  118. }
  119.  
  120. func NewAdapterByDB(db *gorm.DB) (*Adapter, error) {
  121.     a := &Adapter{
  122.         db: db,
  123.     }
  124.  
  125.     err := a.createTable()
  126.     if err != nil {
  127.         return nil, err
  128.     }
  129.  
  130.     return a, nil
  131. }
  132.  
  133. func (a *Adapter) createDatabase() error {
  134.     var err error
  135.     var db *gorm.DB
  136.     if a.driverName == "postgres" {
  137.         db, err = gorm.Open(a.driverName, a.dataSourceName+" dbname=postgres")
  138.     } else {
  139.         db, err = gorm.Open(a.driverName, a.dataSourceName)
  140.     }
  141.     if err != nil {
  142.         return err
  143.     }
  144.  
  145.     if a.driverName == "postgres" {
  146.         if err = db.Exec("CREATE DATABASE casbin").Error; err != nil {
  147.             // 42P04 is duplicate_database
  148.             if err.(*pq.Error).Code == "42P04" {
  149.                 db.Close()
  150.                 return nil
  151.             }
  152.         }
  153.     } else if a.driverName != "sqlite3" {
  154.         err = db.Exec("CREATE DATABASE IF NOT EXISTS casbin").Error
  155.     }
  156.     if err != nil {
  157.         db.Close()
  158.         return err
  159.     }
  160.  
  161.     return db.Close()
  162. }
  163.  
  164. func (a *Adapter) open() error {
  165.     var err error
  166.     var db *gorm.DB
  167.  
  168.     if a.dbSpecified {
  169.         db, err = gorm.Open(a.driverName, a.dataSourceName)
  170.         if err != nil {
  171.             return err
  172.         }
  173.     } else {
  174.         if err = a.createDatabase(); err != nil {
  175.             return err
  176.         }
  177.  
  178.         if a.driverName == "postgres" {
  179.             db, err = gorm.Open(a.driverName, a.dataSourceName+" dbname=casbin")
  180.         } else if a.driverName == "sqlite3" {
  181.             db, err = gorm.Open(a.driverName, a.dataSourceName)
  182.         } else {
  183.             db, err = gorm.Open(a.driverName, a.dataSourceName+"casbin")
  184.         }
  185.         if err != nil {
  186.             return err
  187.         }
  188.     }
  189.  
  190.     a.db = db
  191.  
  192.     return a.createTable()
  193. }
  194.  
  195. func (a *Adapter) close() error {
  196.     err := a.db.Close()
  197.     if err != nil {
  198.         return err
  199.     }
  200.  
  201.     a.db = nil
  202.     return nil
  203. }
  204.  
  205. // getTableInstance return the dynamic table name
  206. func (a *Adapter) getTableInstance() *CasbinRule {
  207.     return &CasbinRule{TablePrefix: a.tablePrefix}
  208. }
  209.  
  210. func (a *Adapter) createTable() error {
  211.     if a.db.HasTable(a.getTableInstance()) {
  212.         return nil
  213.     }
  214.  
  215.     return a.db.CreateTable(a.getTableInstance()).Error
  216. }
  217.  
  218. func (a *Adapter) dropTable() error {
  219.     return a.db.DropTable(a.getTableInstance()).Error
  220. }
  221.  
  222. func loadPolicyLine(line CasbinRule, model model.Model) {
  223.     lineText := line.PType
  224.     if line.V0 != "" {
  225.         lineText += ", " + line.V0
  226.     }
  227.     if line.V1 != "" {
  228.         lineText += ", " + line.V1
  229.     }
  230.     if line.V2 != "" {
  231.         lineText += ", " + line.V2
  232.     }
  233.     if line.V3 != "" {
  234.         lineText += ", " + line.V3
  235.     }
  236.     if line.V4 != "" {
  237.         lineText += ", " + line.V4
  238.     }
  239.     if line.V5 != "" {
  240.         lineText += ", " + line.V5
  241.     }
  242.  
  243.     persist.LoadPolicyLine(lineText, model)
  244. }
  245.  
  246. // LoadPolicy loads policy from database.
  247. func (a *Adapter) LoadPolicy(model model.Model) error {
  248.     var lines []CasbinRule
  249.     if err := a.db.Table(a.tablePrefix + "casbin_rule").Find(&lines).Error; err != nil {
  250.         return err
  251.     }
  252.  
  253.     for _, line := range lines {
  254.         loadPolicyLine(line, model)
  255.     }
  256.  
  257.     return nil
  258. }
  259.  
  260. // LoadFilteredPolicy loads only policy rules that match the filter.
  261. func (a *Adapter) LoadFilteredPolicy(model model.Model, filter interface{}) error {
  262.     var lines []CasbinRule
  263.  
  264.     filterValue, ok := filter.(Filter)
  265.     if !ok {
  266.         return errors.New("invalid filter type")
  267.     }
  268.  
  269.     if err := a.db.Scopes(a.filterQuery(a.db, filterValue)).Find(&lines).Error; err != nil {
  270.         return err
  271.     }
  272.  
  273.     for _, line := range lines {
  274.         loadPolicyLine(line, model)
  275.     }
  276.     a.isFiltered = true
  277.  
  278.     return nil
  279. }
  280.  
  281. // IsFiltered returns true if the loaded policy has been filtered.
  282. func (a *Adapter) IsFiltered() bool {
  283.     return a.isFiltered
  284. }
  285.  
  286. // filterQuery builds the gorm query to match the rule filter to use within a scope.
  287. func (a *Adapter) filterQuery(db *gorm.DB, filter Filter) func(db *gorm.DB) *gorm.DB {
  288.     return func(db *gorm.DB) *gorm.DB {
  289.         if len(filter.PType) > 0 {
  290.             db = db.Where("p_type in (?)", filter.PType)
  291.         }
  292.         if len(filter.V0) > 0 {
  293.             db = db.Where("v0 in (?)", filter.V0)
  294.         }
  295.         if len(filter.V1) > 0 {
  296.             db = db.Where("v1 in (?)", filter.V1)
  297.         }
  298.         if len(filter.V2) > 0 {
  299.             db = db.Where("v2 in (?)", filter.V2)
  300.         }
  301.         if len(filter.V3) > 0 {
  302.             db = db.Where("v3 in (?)", filter.V3)
  303.         }
  304.         if len(filter.V4) > 0 {
  305.             db = db.Where("v4 in (?)", filter.V4)
  306.         }
  307.         if len(filter.V5) > 0 {
  308.             db = db.Where("v5 in (?)", filter.V5)
  309.         }
  310.         return db
  311.     }
  312. }
  313.  
  314. func (a *Adapter) savePolicyLine(ptype string, rule []string) CasbinRule {
  315.     line := a.getTableInstance()
  316.  
  317.     line.PType = ptype
  318.     if len(rule) > 0 {
  319.         line.V0 = rule[0]
  320.     }
  321.     if len(rule) > 1 {
  322.         line.V1 = rule[1]
  323.     }
  324.     if len(rule) > 2 {
  325.         line.V2 = rule[2]
  326.     }
  327.     if len(rule) > 3 {
  328.         line.V3 = rule[3]
  329.     }
  330.     if len(rule) > 4 {
  331.         line.V4 = rule[4]
  332.     }
  333.     if len(rule) > 5 {
  334.         line.V5 = rule[5]
  335.     }
  336.  
  337.     return *line
  338. }
  339.  
  340. // SavePolicy saves policy to database.
  341. func (a *Adapter) SavePolicy(model model.Model) error {
  342.     err := a.dropTable()
  343.     if err != nil {
  344.         return err
  345.     }
  346.     err = a.createTable()
  347.     if err != nil {
  348.         return err
  349.     }
  350.  
  351.     for ptype, ast := range model["p"] {
  352.         for _, rule := range ast.Policy {
  353.             line := a.savePolicyLine(ptype, rule)
  354.             err := a.db.Create(&line).Error
  355.             if err != nil {
  356.                 return err
  357.             }
  358.         }
  359.     }
  360.  
  361.     for ptype, ast := range model["g"] {
  362.         for _, rule := range ast.Policy {
  363.             line := a.savePolicyLine(ptype, rule)
  364.             err := a.db.Create(&line).Error
  365.             if err != nil {
  366.                 return err
  367.             }
  368.         }
  369.     }
  370.  
  371.     return nil
  372. }
  373.  
  374. // AddPolicy adds a policy rule to the storage.
  375. func (a *Adapter) AddPolicy(sec string, ptype string, rule []string) error {
  376.     line := a.savePolicyLine(ptype, rule)
  377.     err := a.db.Create(&line).Error
  378.     return err
  379. }
  380.  
  381. // RemovePolicy removes a policy rule from the storage.
  382. func (a *Adapter) RemovePolicy(sec string, ptype string, rule []string) error {
  383.     line := a.savePolicyLine(ptype, rule)
  384.     err := a.rawDelete(a.db, line) //can't use db.Delete as we're not using primary key http://jinzhu.me/gorm/crud.html#delete
  385.     return err
  386. }
  387.  
  388. // RemoveFilteredPolicy removes policy rules that match the filter from the storage.
  389. func (a *Adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
  390.     line := a.getTableInstance()
  391.  
  392.     line.PType = ptype
  393.     if fieldIndex <= 0 && 0 < fieldIndex+len(fieldValues) {
  394.         line.V0 = fieldValues[0-fieldIndex]
  395.     }
  396.     if fieldIndex <= 1 && 1 < fieldIndex+len(fieldValues) {
  397.         line.V1 = fieldValues[1-fieldIndex]
  398.     }
  399.     if fieldIndex <= 2 && 2 < fieldIndex+len(fieldValues) {
  400.         line.V2 = fieldValues[2-fieldIndex]
  401.     }
  402.     if fieldIndex <= 3 && 3 < fieldIndex+len(fieldValues) {
  403.         line.V3 = fieldValues[3-fieldIndex]
  404.     }
  405.     if fieldIndex <= 4 && 4 < fieldIndex+len(fieldValues) {
  406.         line.V4 = fieldValues[4-fieldIndex]
  407.     }
  408.     if fieldIndex <= 5 && 5 < fieldIndex+len(fieldValues) {
  409.         line.V5 = fieldValues[5-fieldIndex]
  410.     }
  411.     err := a.rawDelete(a.db, *line)
  412.     return err
  413. }
  414.  
  415. func (a *Adapter) rawDelete(db *gorm.DB, line CasbinRule) error {
  416.     queryArgs := []interface{}{line.PType}
  417.  
  418.     queryStr := "p_type = ?"
  419.     if line.V0 != "" {
  420.         queryStr += " and v0 = ?"
  421.         queryArgs = append(queryArgs, line.V0)
  422.     }
  423.     if line.V1 != "" {
  424.         queryStr += " and v1 = ?"
  425.         queryArgs = append(queryArgs, line.V1)
  426.     }
  427.     if line.V2 != "" {
  428.         queryStr += " and v2 = ?"
  429.         queryArgs = append(queryArgs, line.V2)
  430.     }
  431.     if line.V3 != "" {
  432.         queryStr += " and v3 = ?"
  433.         queryArgs = append(queryArgs, line.V3)
  434.     }
  435.     if line.V4 != "" {
  436.         queryStr += " and v4 = ?"
  437.         queryArgs = append(queryArgs, line.V4)
  438.     }
  439.     if line.V5 != "" {
  440.         queryStr += " and v5 = ?"
  441.         queryArgs = append(queryArgs, line.V5)
  442.     }
  443.     args := append([]interface{}{queryStr}, queryArgs...)
  444.     err := db.Delete(a.getTableInstance(), args...).Error
  445.     return err
  446. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement