Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package main
- import (
- "flag"
- "fmt"
- "io"
- "os"
- "os/exec"
- "sort"
- "strings"
- "time"
- )
- func init() {
- sort.Strings(builtins)
- }
- func main() {
- lex := flag.Bool("lex", false, "lex")
- flag.Parse()
- data := []string{
- "xxh -k 64 scripts/basic.go scripts/quote.go",
- "find tmp/ -type f -name *txt",
- "find /tmp -type f ; xxh -k 64 scripts/quote.go",
- "find /tmp -type f || xxh -k 64 /tmp",
- "find /tmp -type f && xxh -k 64 /tmp",
- "echo foobar; find /tmp -type f && xxh -k 64 /tmp",
- "echo foobar#a comment",
- "find tmp/ -maxdepth 1 -type f | xxh",
- "NAME=foobar",
- "NAME=foobar;",
- "NAME=",
- "NAME=;",
- "NAME=foo bar baz",
- "$VAR",
- "${VAR}",
- "$((1+35+1))",
- "$((1 + (7*5) + 1))",
- "$((1 + (7 * VAR) + 1))",
- "$((1 + (7 * $VAR) + 1))",
- }
- if *lex {
- lexOnly(data)
- } else {
- parseOnly(data)
- }
- }
- func parseOnly(data []string) {
- for i, d := range data {
- if i > 0 {
- fmt.Println("---")
- }
- fmt.Println("parsing:", d)
- if c, err := NewParser(d).Parse(); err != nil {
- fmt.Fprintln(os.Stderr, err)
- } else {
- fmt.Println(c.String())
- }
- }
- }
- func lexOnly(data []string) {
- for i, d := range data {
- if i > 0 {
- fmt.Println("---")
- }
- fmt.Println("lexing:", d)
- x := Lex(d)
- for t := x.Next(); t.Type != EOF; t = x.Next() {
- fmt.Println(">", t)
- }
- }
- }
- type shell struct {
- globals map[string]string
- locals map[string]string
- level uint64
- pwd string // previous working directory
- cwd string // current working directory
- home string // home directory
- now time.Time
- }
- func NewShell() *shell {
- x := shell{
- globals: make(map[string]string),
- locals: make(map[string]string),
- now: time.Now(),
- cwd: os.TempDir(),
- pwd: os.TempDir(),
- home: os.TempDir(),
- }
- x.level++
- return &x
- }
- func (s *shell) Exec() error {
- return s.execCommand(nil)
- }
- func (s *shell) execCommand(e Exec) error {
- var err error
- switch e.(type) {
- case sequence:
- case pipeline:
- case logical:
- case command:
- default:
- err = fmt.Errorf("unsupported syntax %T", e)
- }
- return err
- }
- func (s *shell) Subshell() *shell {
- x := NewShell()
- x.level = s.level + 1
- for k, v := range s.globals {
- x.globals[k] = v
- }
- for k, v := range s.locals {
- x.locals[k] = v
- }
- return x
- }
// Binding powers for the Pratt-style parsers, weakest first. In
// parseCommand an operator is consumed only while its power is greater than
// the caller's current power. Values come from iota, so the order of this
// list is significant.
const (
	bindLowest = iota // ';', '&' and any token not listed in powers
	bindEqual
	bindShift
	bindPlus
	bindMultiply
	bindExponent
	bindNegation
	bindUnary
	bindLogical // && and ||
	bindPipe    // | (binds tighter than && and ||)
)
- var powers = map[int]int{
- Background: bindLowest,
- Sequence: bindLowest,
- And: bindLogical,
- Or: bindLogical,
- Pipeline: bindPipe,
- }
// Exec is implemented by every node of a parsed command line. Exec starts
// the node, Wait runs it to completion (command.Wait starts the process and
// then waits for it), and String renders the node back as shell text.
type Exec interface {
	fmt.Stringer
	Exec() error
	Wait() error
}

// literal is a plain word node; running it does nothing.
type literal string

func (l literal) String() string { return string(l) }

func (literal) Exec() error { return nil }

func (literal) Wait() error { return nil }
- type command struct {
- cmd *exec.Cmd
- async bool
- stdout string
- stderr string
- stdin string
- }
- func newCommand(as []string) *command {
- var c command
- c.cmd = exec.Command(as[0], as[1:]...)
- return &c
- }
- func (c command) String() string {
- var b strings.Builder
- b.WriteString(c.cmd.Args[0])
- if len(c.cmd.Args) > 1 {
- for _, a := range c.cmd.Args[1:] {
- b.WriteRune(space)
- b.WriteString(a)
- }
- }
- return b.String()
- }
- func (c command) Exec() error {
- if c.cmd.Stdout == nil {
- c.cmd.Stdout = os.Stdout
- }
- if c.cmd.Stderr == nil {
- c.cmd.Stderr = os.Stderr
- }
- return c.cmd.Start()
- }
- func (c command) Wait() error {
- err := c.Exec()
- if err == nil {
- err = c.cmd.Wait()
- }
- return err
- }
- type sequence []Exec
- func (es sequence) String() string {
- var b strings.Builder
- for i, e := range es {
- if i >= 1 {
- b.WriteRune(space)
- }
- b.WriteString(e.String())
- b.WriteRune(semicolon)
- }
- return b.String()
- }
- func (es sequence) Exec() error {
- return es.Wait()
- }
- func (es sequence) Wait() error {
- var err error
- for _, e := range es {
- err = e.Wait()
- }
- return err
- }
- type logical struct {
- op int
- left Exec
- right Exec
- }
- func (i logical) String() string {
- var op string
- switch i.op {
- case And:
- op = "&&"
- case Or:
- op = "||"
- default:
- op = "???"
- }
- var b strings.Builder
- b.WriteString(i.left.String())
- b.WriteRune(space)
- b.WriteString(op)
- b.WriteRune(space)
- b.WriteString(i.right.String())
- return b.String()
- }
- func (i logical) Exec() error {
- return i.Wait()
- }
- func (i logical) Wait() error {
- var fn func(Exec, Exec) error
- switch i.op {
- case Or:
- fn = evalOr
- case And:
- fn = evalAnd
- default:
- }
- if fn == nil {
- return fmt.Errorf("unkown logical operator")
- }
- return fn(i.left, i.right)
- }
- func evalOr(left, right Exec) error {
- err := left.Wait()
- if err != nil {
- err = right.Wait()
- }
- return err
- }
- func evalAnd(left, right Exec) error {
- err := left.Wait()
- if err == nil {
- err = right.Wait()
- }
- return err
- }
- type pipeline []Exec
- func (p pipeline) Exec() error {
- return p.Wait()
- }
- func (p pipeline) Wait() error {
- var pipe io.Reader = os.Stdin
- for i, e := range p {
- if c, ok := e.(*command); ok {
- c.cmd.Stdin = pipe
- pipe, _ = c.cmd.StdoutPipe()
- } else {
- return fmt.Errorf("invalid construct")
- }
- if i < len(p)-1 {
- go e.Wait()
- }
- }
- go io.Copy(os.Stdout, pipe)
- return p[len(p)-1].Wait()
- }
- func (p pipeline) String() string {
- var b strings.Builder
- for i, e := range p {
- if i >= 1 {
- b.WriteRune(space)
- b.WriteRune(pipe)
- b.WriteRune(space)
- }
- b.WriteString(e.String())
- }
- return b.String()
- }
- type assignment struct {
- ident string
- value Exec
- }
- func (a assignment) Exec() error {
- return nil
- }
- func (a assignment) Wait() error {
- return nil
- }
- func (a assignment) String() string {
- var b strings.Builder
- b.WriteString(a.ident)
- b.WriteRune(equal)
- if a.value != nil {
- b.WriteString(a.value.String())
- }
- return b.String()
- }
// arithmetic is a $((...)) node wrapping its parsed expression. Running it
// is currently a no-op.
type arithmetic struct {
	expr Expression
}

func (a arithmetic) String() string {
	return a.expr.String()
}

func (a arithmetic) Exec() error {
	return nil
}

func (a arithmetic) Wait() error {
	return nil
}

// Expression is a node of a parsed arithmetic expression.
type Expression interface {
	fmt.Stringer
}

// unary is a prefix operator applied to one operand, e.g. "-x". right must
// be non-nil for String to work.
type unary struct {
	op    byte
	right Expression
}

// String renders the node fully parenthesised, e.g. "(-x)", so the parse
// structure is visible when printed.
func (u unary) String() string {
	return "(" + string(rune(u.op)) + u.right.String() + ")"
}

// infix is a binary operator node, e.g. "a + b". Both operands must be
// non-nil for String to work.
type infix struct {
	op    byte
	left  Expression
	right Expression
}

// String renders the node fully parenthesised, e.g. "(1 + (2 * x))".
func (i infix) String() string {
	return "(" + i.left.String() + " " + string(rune(i.op)) + " " + i.right.String() + ")"
}

// Compile-time checks: the tree nodes must themselves be Expressions so they
// can be nested.
var (
	_ Expression = unary{}
	_ Expression = infix{}
)
// builtins lists the command names the lexer tags as Builtin tokens. Listed
// here alphabetically; init sorts it anyway because readLiteral
// binary-searches it.
var builtins = []string{
	"cd", "echo", "export", "false",
	"now", "pwd", "true", "wait",
}
- type parser struct {
- lex *lexer
- // prefix infix function to parse shell construct
- prefix map[int]func() (Exec, error)
- infix map[int]func(Exec) (Exec, error)
- curr Token
- peek Token
- }
- type mathparser struct {
- lex *lexer
- prefix map[byte]func() (Expression, error)
- infix map[byte]func(Expression) (Expression, error)
- curr Token
- peek Token
- }
- func parseMath(lex *lexer) (Expression, error) {
- var p mathparser
- p.init(lex)
- return p.parseExpression(bindLowest)
- }
- func (p *mathparser) init(lex *lexer) {
- p.lex = lex
- p.prefix = map[byte]func() (Expression, error){}
- p.infix = map[byte]func(Expression) (Expression, error){}
- p.nextToken()
- p.nextToken()
- }
- func (p *mathparser) parseExpression(power int) (Expression, error) {
- prefix, ok := p.prefix[p.curr.Type]
- if !ok {
- return nil, fmt.Errorf("can not parse %s as prefix operator", p.curr)
- }
- left, err := prefix()
- if err != nil {
- return nil, err
- }
- for power < p.peekPower() {
- infix, ok := p.infix[p.peek.Type]
- if !ok {
- return nil, fmt.Errorf("can not parse %s as infix operator", p.curr)
- }
- left, err = infix(left)
- if err != nil {
- return nil, err
- }
- }
- return left, nil
- }
- func (m *mathparser) nextToken() {
- m.curr = m.peek
- m.peek = m.lext.nextToken()
- }
- func NewParser(str string) *parser {
- var p parser
- p.prefix = map[int]func() (Exec, error){
- Literal: p.parseSimple,
- Builtin: p.parseSimple,
- Identifier: p.parseAssignment,
- Arithmetic: p.parseArithmetic,
- }
- p.infix = map[int]func(Exec) (Exec, error){
- And: p.parseLogical,
- Or: p.parseLogical,
- Pipeline: p.parsePipe,
- }
- p.lex = Lex(str)
- p.nextToken()
- p.nextToken()
- return &p
- }
- func (p *parser) Parse() (Exec, error) {
- xs := make(sequence, 0, 8)
- for !p.isDone() {
- e, err := p.parseCommand(bindLowest)
- if err != nil {
- return nil, err
- }
- xs = append(xs, e)
- p.nextToken()
- }
- return xs, nil
- }
- func (p *parser) parseCommand(power int) (Exec, error) {
- fmt.Println("parseCommand:", p.curr)
- prefix, ok := p.prefix[p.curr.Type]
- if !ok {
- return nil, fmt.Errorf("can not parse %s as command prefix", p.curr)
- }
- e, err := prefix()
- if err != nil {
- return nil, err
- }
- for !p.isComplete() && power < p.currPower() {
- infix, ok := p.infix[p.curr.Type]
- if !ok {
- return nil, fmt.Errorf("can not parse %s as command infix", p.peek)
- }
- e, err = infix(e)
- if err != nil {
- return nil, err
- }
- }
- return e, nil
- }
- func (p *parser) parseArithmetic() (Exec, error) {
- fmt.Println("parseArithmetic:", p.curr)
- i := literal(p.curr.Literal)
- if _, err := parseMath(p.lex); err != nil {
- return nil, err
- }
- p.nextToken()
- return i, nil
- }
- func (p *parser) parseAssignment() (Exec, error) {
- fmt.Println("parseAssignment:", p.curr)
- a := assignment{ident: p.curr.Literal}
- p.nextToken()
- if p.curr.Type != Assign {
- return nil, fmt.Errorf("invalid syntax: %s", p.curr)
- }
- var err error
- if typ := p.peek.Type; typ != EOF && typ != Sequence {
- p.nextToken()
- a.value, err = p.parseSimple()
- } else {
- p.nextToken()
- }
- return a, err
- }
- func (p *parser) parsePipe(left Exec) (Exec, error) {
- fmt.Println("parsePipe:", p.curr)
- ps := make(pipeline, 0, 8)
- ps = append(ps, left)
- for !p.isComplete() {
- p.nextToken()
- left, err := p.parseSimple()
- if err != nil {
- return nil, err
- }
- ps = append(ps, left)
- if p.curr.Type != Pipeline && p.curr.Type != EOF {
- return nil, fmt.Errorf("invalid syntax: %s", p.peek)
- }
- }
- return ps, nil
- }
- func (p *parser) parseLogical(left Exec) (Exec, error) {
- fmt.Println("parseLogical:", p.curr)
- logic := logical{
- left: left,
- op: p.curr.Type,
- }
- p.nextToken()
- right, err := p.parseSimple()
- if err != nil {
- return nil, err
- }
- logic.right = right
- return logic, nil
- }
- func (p *parser) parseSimple() (Exec, error) {
- fmt.Println("parseSimple:", p.curr)
- var args []string
- for {
- if p.isComplete() || p.isControl(p.curr) {
- break
- }
- args = append(args, p.curr.Literal)
- p.nextToken()
- }
- return newCommand(args), nil
- }
- func (p *parser) isControl(tok Token) bool {
- t := tok.Type
- return t == And || t == Or || t == Background || t == Pipeline || t == Sequence || t == EOF
- }
- func (p *parser) isComplete() bool {
- return p.curr.Type == EOF || p.curr.Type == Sequence || p.curr.Type == Comment
- }
- func (p *parser) isDone() bool {
- return p.curr.Type == EOF
- }
- func (p *parser) currPower() int {
- return bindingPower(p.curr)
- }
- func (p *parser) peekPower() int {
- return bindingPower(p.peek)
- }
- func (p *parser) nextToken() {
- if p.isDone() {
- return
- }
- p.curr = p.peek
- p.peek = p.lex.Next()
- if p.curr.Type == Comment {
- for !p.isDone() {
- p.nextToken()
- }
- }
- }
- func bindingPower(tok Token) int {
- p, ok := powers[tok.Type]
- if !ok {
- p = bindLowest
- }
- return p
- }
// Byte values the lexer recognises.
const (
	// eof is what readByte/peekByte yield past the end of the input. A NUL
	// byte inside the input is therefore indistinguishable from the end.
	eof = 0

	// Blanks.
	tab   = '\t'
	space = ' '

	// Shell metacharacters.
	dollar    = '$'
	semicolon = ';'
	ampersand = '&'
	pipe      = '|'
	comment   = '#'
	quote     = '\''
	equal     = '='
	question  = '?'
	bang      = '!'

	// Brackets.
	lparen = '('
	rparen = ')'
	lcurly = '{'
	rcurly = '}'
	langle = '<'
	rangle = '>'

	// Arithmetic operators.
	plus  = '+'
	minus = '-'
	star  = '*'
	slash = '/'
)
// Token types. All of them are negative, so a Token.Type of 0 means "not
// set yet" (readLiteral relies on this). labels must list their names in
// the reverse order of this block: the label for type t is
// labels[len(labels)+t].
const (
	EOF = -(iota + 1)
	Literal
	Builtin
	Identifier
	Variable
	Arithmetic
	Comment
	And
	Or
	Pipeline
	Background
	Sequence
	Assign
	Unknown
	Invalid
)

// labels holds the display names of the token types, from Invalid to EOF.
var labels = []string{
	"invalid",
	"unknown",
	"assignment",
	"sequence",
	"background",
	"pipeline",
	"or",
	"and",
	"comment",
	"arithmetic",
	"variable",
	"identifier",
	"builtin",
	"literal",
	"eof",
}

// Token is one lexeme. Type is one of the constants above; Literal is the
// token text (empty for most operators). Op is set to the operator byte for
// Arithmetic tokens produced in arithmetic mode.
type Token struct {
	Op      byte
	Type    int
	Literal string
}

// String renders the token as "<label (literal)>". A type outside the
// known range, including the unset value 0, is labelled "invalid".
// Previously only out-of-range negatives were handled, and a Type of 0 (as
// produced for e.g. "$x" or "$1") made this panic with an index out of
// range.
func (t Token) String() string {
	label := labels[0]
	if n := len(labels) + t.Type; n >= 0 && n < len(labels) {
		label = labels[n]
	}
	return fmt.Sprintf("<%s (%s)>", label, t.Literal)
}
// lexerState enumerates lexing modes.
// NOTE(review): none of these values is referenced elsewhere in this file;
// arithmetic mode is currently switched on through lexer.scan instead.
// The names suggest quoting and subshell modes are planned — confirm.
type lexerState uint16

const (
	lexDefault lexerState = iota
	lexStrict
	lexWeak
	lexSubshell
	lexArithmetic
)
- type lexer struct {
- input []byte
- char byte
- pos int
- next int
- scan func() Token
- }
- func Lex(str string) *lexer {
- x := lexer{input: []byte(str)}
- x.readByte()
- return &x
- }
- func (x *lexer) scanArithmetic() Token {
- var tok Token
- if b := x.peekByte(); x.char == rparen && (b == rparen || b == eof) {
- fmt.Println("all done")
- x.scan = nil
- tok.Type = EOF
- return tok
- }
- switch {
- case isDigit(x.char):
- x.readNumber(&tok)
- case isLetter(x.char) || x.char == dollar:
- if x.char == dollar {
- x.readByte()
- }
- tok.Type = Variable
- x.readLiteral(&tok)
- default:
- tok.Literal = string(x.char)
- tok.Type, tok.Op = Arithmetic, x.char
- x.readByte()
- }
- return tok
- }
- func (x *lexer) Next() Token {
- var tok Token
- if x.char == eof {
- tok.Type = EOF
- return tok
- }
- x.skipBlank()
- if x.scan != nil {
- return x.scan()
- }
- switch {
- case x.char == '$':
- x.readByte()
- if b := x.peekByte(); x.char == lcurly || isLetter(b) {
- tok.Type = Variable
- if x.char == lcurly {
- x.readByte()
- x.readUntil(&tok, func(b byte) bool { return b == rcurly })
- } else {
- x.readLiteral(&tok)
- }
- } else if x.char == lparen && b == lparen {
- tok.Type = Arithmetic
- x.readByte()
- x.scan = x.scanArithmetic
- } else if x.char == question || x.char == bang {
- tok.Type = Variable
- }
- case isControl(x.char):
- x.readControl(&tok)
- case isComment(x.char):
- x.readByte()
- x.readComment(&tok)
- default:
- x.readLiteral(&tok)
- if b := x.peekByte(); x.char == equal && (!isBlank(b) || b == eof) {
- tok.Type = Identifier
- }
- }
- x.readByte()
- return tok
- }
- func (x *lexer) readUntil(tok *Token, fn func(b byte) bool) {
- pos := x.pos
- for !fn(x.char) {
- x.readByte()
- }
- tok.Literal = string(x.input[pos:x.pos])
- }
- func (x *lexer) readNumber(tok *Token) {
- pos := x.pos
- for isDigit(x.char) {
- x.readByte()
- }
- tok.Literal = string(x.input[pos:x.pos])
- tok.Type = Literal
- }
- func (x *lexer) readLiteral(tok *Token) {
- pos := x.pos
- for {
- if isSeparator(x.char) || isComment(x.char) {
- break
- }
- x.readByte()
- }
- tok.Literal = string(x.input[pos:x.pos])
- if tok.Type >= 0 {
- tok.Type = Literal
- }
- if isSeparator(x.char) || isComment(x.char) {
- x.unreadByte()
- }
- ix := sort.SearchStrings(builtins, tok.Literal)
- if ix < len(builtins) && builtins[ix] == tok.Literal {
- tok.Type = Builtin
- }
- }
- func (x *lexer) readComment(tok *Token) {
- pos := x.pos
- for x.char != eof {
- x.readByte()
- }
- tok.Literal, tok.Type = string(x.input[pos:x.pos]), Comment
- }
- func (x *lexer) readControl(tok *Token) {
- switch x.char {
- case ampersand:
- x.readByte()
- if x.char == ampersand {
- tok.Type = And
- } else {
- tok.Type = Background
- }
- case pipe:
- x.readByte()
- if x.char == pipe {
- tok.Type = Or
- } else {
- tok.Type = Pipeline
- }
- case semicolon:
- tok.Type = Sequence
- case equal:
- tok.Type = Assign
- default:
- tok.Type = Invalid
- }
- }
- func (x *lexer) peekByte() byte {
- if x.next >= len(x.input) {
- return eof
- }
- return x.input[x.next]
- }
- func (x *lexer) readByte() {
- if x.next >= len(x.input) {
- x.char = eof
- } else {
- x.char = x.input[x.next]
- }
- x.pos = x.next
- x.next++
- }
- func (x *lexer) unreadByte() {
- x.next = x.pos
- x.pos--
- }
- func (x *lexer) skipBlank() {
- for isBlank(x.char) {
- x.readByte()
- }
- }
- func isSeparator(b byte) bool {
- return isBlank(b) || isControl(b)
- }
- func isBlank(b byte) bool {
- return b == space || b == tab || b == eof
- }
- func isMath(b byte) bool {
- return b == plus || b == minus || b == slash || b == star
- }
- func isControl(b byte) bool {
- return b == ampersand || b == pipe || b == semicolon || b == equal
- }
- func isComment(b byte) bool {
- return b == comment
- }
- func isQuote(b byte) bool {
- return b == quote
- }
- func isLetter(b byte) bool {
- return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z')
- }
- func isDigit(b byte) bool {
- return b >= '0' && b <= '9'
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement