about summary refs log blame commit diff stats
path: root/vendor/gotest.tools/internal/source/source.go
blob: a05933cc3380e23dc607d4b4f7a935edb50506ec (plain) (tree)


































































































































































                                                                                           
package source // import "gotest.tools/internal/source"

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
	"os"
	"runtime"
	"strconv"
	"strings"

	"github.com/pkg/errors"
)

// baseStackIndex is added to the stackIndex passed to runtime.Caller in
// CallExprArgs so that the frame of CallExprArgs itself is skipped: a
// stackIndex of 0 then refers to the function that called CallExprArgs.
const baseStackIndex = 1

// FormattedCallExprArg returns the argument from an ast.CallExpr at the
// index in the call stack. The argument is formatted using FormatNode.
//
// A stackIndex of 0 refers to the source line that called
// FormattedCallExprArg; argPos is the zero-based position of the argument
// within that call. An error is returned, rather than a panic, when argPos
// does not name an argument of the call.
func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
	// +1 skips this function's own frame.
	args, err := CallExprArgs(stackIndex + 1)
	if err != nil {
		return "", err
	}
	if argPos < 0 || argPos >= len(args) {
		return "", errors.New("failed to find expression: argument position out of range")
	}
	return FormatNode(args[argPos])
}

// getNodeAtLine parses filename and returns the first node, in ast.Walk
// order, that scanToLine matches against lineNum. It returns an error when
// the file cannot be parsed or when no node matches the line.
func getNodeAtLine(filename string, lineNum int) (ast.Node, error) {
	fset := token.NewFileSet()
	file, parseErr := parser.ParseFile(fset, filename, nil, parser.AllErrors)
	if parseErr != nil {
		return nil, errors.Wrapf(parseErr, "failed to parse source file: %s", filename)
	}

	if found := scanToLine(fset, file, lineNum); found != nil {
		return found, nil
	}
	return nil, errors.Errorf(
		"failed to find an expression on line %d in %s", lineNum, filename)
}

// scanToLine walks node (node itself included) and returns the first node
// whose position, as computed by scanToLineVisitor.nodePosition, is on
// lineNum. It returns nil when nothing matches.
func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
	visitor := scanToLineVisitor{fileset: fileset, lineNum: lineNum}
	ast.Walk(&visitor, node)
	return visitor.matchedNode
}

// scanToLineVisitor is an ast.Visitor that stops at the first node, in
// ast.Walk order, whose position (from nodePosition) is on lineNum. The
// match is stored in matchedNode, which stays nil if no node matches.
// fileset resolves node positions to line numbers.
type scanToLineVisitor struct {
	lineNum     int
	matchedNode ast.Node
	fileset     *token.FileSet
}

// Visit implements ast.Visitor. It records the first node positioned on
// v.lineNum. It returns nil for a nil node, once a match exists, or right
// after recording the match, so ast.Walk does not descend into that node's
// children. Otherwise it keeps walking.
func (v *scanToLineVisitor) Visit(node ast.Node) ast.Visitor {
	switch {
	case node == nil || v.matchedNode != nil:
		return nil
	case v.nodePosition(node).Line == v.lineNum:
		v.matchedNode = node
		return nil
	default:
		return v
	}
}

// nodePosition returns the position of node that is compared against the
// line number reported by runtime.Caller. In Go 1.9 that line number changed
// from the line where a statement ended to the line where it began. Older
// toolchains therefore match on the node's end, and newer ones on its start.
func (v *scanToLineVisitor) nodePosition(node ast.Node) token.Position {
	var at token.Pos
	if goVersionBefore19 {
		at = node.End()
	} else {
		at = node.Pos()
	}
	return v.fileset.Position(at)
}

// goVersionBefore19 records whether the running toolchain predates Go 1.9.
// That release changed which line runtime.Caller reports for a statement
// (see nodePosition).
var goVersionBefore19 = isGOVersionBefore19()

// isGOVersionBefore19 reports whether runtime.Version() names a Go 1.x
// release whose minor version is below 9.
func isGOVersionBefore19() bool {
	return isVersionBefore19(runtime.Version())
}

// isVersionBefore19 reports whether version, in the form returned by
// runtime.Version (e.g. "go1.8.3" or "go1.8rc1"), names a Go 1.x release
// whose minor version is below 9. Strings that do not start with "go", such
// as development builds, and strings that cannot be parsed report false.
func isVersionBefore19(version string) bool {
	// not a release version
	if !strings.HasPrefix(version, "go") {
		return false
	}
	parts := strings.Split(strings.TrimPrefix(version, "go"), ".")
	if len(parts) < 2 || parts[0] != "1" {
		return false
	}
	// Pre-release versions carry a suffix on the minor component ("8rc1",
	// "9beta2"); only the leading digits identify the minor version.
	minorText := parts[1]
	if i := strings.IndexFunc(minorText, func(r rune) bool { return r < '0' || r > '9' }); i >= 0 {
		minorText = minorText[:i]
	}
	minor, err := strconv.ParseInt(minorText, 10, 32)
	return err == nil && minor < 9
}

// getCallExprArgs returns the arguments of the first *ast.CallExpr found
// when walking node (node itself included). It returns an error when node
// contains no call expression.
func getCallExprArgs(node ast.Node) ([]ast.Expr, error) {
	var finder callExprVisitor
	ast.Walk(&finder, node)
	if call := finder.expr; call != nil {
		return call.Args, nil
	}
	return nil, errors.New("failed to find call expression")
}

// callExprVisitor is an ast.Visitor that records in expr the first
// *ast.CallExpr encountered in ast.Walk order. expr stays nil if no call
// expression is found.
type callExprVisitor struct {
	expr *ast.CallExpr
}

// Visit implements ast.Visitor. It logs each visited node via debug and
// stores the first *ast.CallExpr in v.expr. For a nil node, once a call has
// been found, or right after storing it, it returns nil so ast.Walk does not
// descend further.
func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
	if node == nil || v.expr != nil {
		return nil
	}
	debug("visit (%T): %s", node, debugFormatNode{node})

	call, isCall := node.(*ast.CallExpr)
	if !isCall {
		return v
	}
	v.expr = call
	return nil
}

// FormatNode formats node with go/format.Node, using a fresh token.FileSet,
// and returns the result as a string. If formatting fails, it returns the
// error together with whatever output was written before the failure.
func FormatNode(node ast.Node) (string, error) {
	var out bytes.Buffer
	err := format.Node(&out, token.NewFileSet(), node)
	return out.String(), err
}

// CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at
// the index in the call stack.
//
// A stackIndex of 0 refers to the source line that called CallExprArgs. The
// file containing that line is parsed, the first node on the line is
// located, and the arguments of the first call expression within that node
// are returned.
func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
	_, file, line, found := runtime.Caller(baseStackIndex + stackIndex)
	if !found {
		return nil, errors.New("failed to get call stack")
	}
	debug("call stack position: %s:%d", file, line)

	node, err := getNodeAtLine(file, line)
	if err != nil {
		return nil, err
	}
	debug("found node (%T): %s", node, debugFormatNode{node})
	return getCallExprArgs(node)
}

// debugEnabled is true when the GOTESTYOURSELF_DEBUG environment variable is
// set to a non-empty value at package initialization.
var debugEnabled = os.Getenv("GOTESTYOURSELF_DEBUG") != ""

// debug writes a "DEBUG: "-prefixed, newline-terminated message to stderr
// when debugEnabled is set. Otherwise it does nothing.
func debug(msgFormat string, args ...interface{}) {
	if !debugEnabled {
		return
	}
	fmt.Fprintf(os.Stderr, "DEBUG: "+msgFormat+"\n", args...)
}

// debugFormatNode wraps an ast.Node so that printing it with %s renders the
// node through FormatNode (see its String method). The node is formatted
// only when fmt actually prints the value, i.e. when a debug message is
// emitted.
type debugFormatNode struct {
	ast.Node
}

// String implements fmt.Stringer. It returns the node formatted by
// FormatNode, or a message describing the formatting error.
func (n debugFormatNode) String() string {
	formatted, err := FormatNode(n.Node)
	if err == nil {
		return formatted
	}
	return fmt.Sprintf("failed to format %s: %s", n.Node, err)
}