Skip to content

Commit c4e4b68

Browse files
fix: check column references in ORDER BY (#1411) (#1915)
* fix: check column references in ORDER BY (#1411) * test: move test cases to endtoend tests * feat: add validate_order_by config option #1411 * feat: expand error message #1411 Tell the uses how to switch off validation here. * feat: add expanded error message to test #1411 * compiler: Add functions to the compiler struct Don't pass configuration around as a parameter --------- Co-authored-by: Kyle Conroy <kyle@conroy.org>
1 parent 9b9a2b6 commit c4e4b68

File tree

16 files changed

+119
-21
lines changed

16 files changed

+119
-21
lines changed

Diff for: internal/compiler/expand.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func (c *Compiler) quoteIdent(ident string) string {
5555
}
5656

5757
func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node) ([]source.Edit, error) {
58-
tables, err := sourceTables(qc, node)
58+
tables, err := c.sourceTables(qc, node)
5959
if err != nil {
6060
return nil, err
6161
}

Diff for: internal/compiler/output_columns.go

+52-15
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ import (
1414

1515
// OutputColumns determines which columns a statement will output
1616
func (c *Compiler) OutputColumns(stmt ast.Node) ([]*catalog.Column, error) {
17-
qc, err := buildQueryCatalog(c.catalog, stmt, nil)
17+
qc, err := c.buildQueryCatalog(c.catalog, stmt, nil)
1818
if err != nil {
1919
return nil, err
2020
}
21-
cols, err := outputColumns(qc, stmt)
21+
cols, err := c.outputColumns(qc, stmt)
2222
if err != nil {
2323
return nil, err
2424
}
@@ -51,8 +51,8 @@ func hasStarRef(cf *ast.ColumnRef) bool {
5151
//
5252
// Return an error if column references are ambiguous
5353
// Return an error if column references don't exist
54-
func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
55-
tables, err := sourceTables(qc, node)
54+
func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
55+
tables, err := c.sourceTables(qc, node)
5656
if err != nil {
5757
return nil, err
5858
}
@@ -68,21 +68,50 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
6868

6969
if n.GroupClause != nil {
7070
for _, item := range n.GroupClause.Items {
71-
ref, ok := item.(*ast.ColumnRef)
72-
if !ok {
73-
continue
74-
}
75-
76-
if err := findColumnForRef(ref, tables, n); err != nil {
71+
if err := findColumnForNode(item, tables, n); err != nil {
7772
return nil, err
7873
}
7974
}
8075
}
76+
validateOrderBy := true
77+
if c.conf.StrictOrderBy != nil {
78+
validateOrderBy = *c.conf.StrictOrderBy
79+
}
80+
if validateOrderBy {
81+
if n.SortClause != nil {
82+
for _, item := range n.SortClause.Items {
83+
sb, ok := item.(*ast.SortBy)
84+
if !ok {
85+
continue
86+
}
87+
if err := findColumnForNode(sb.Node, tables, n); err != nil {
88+
return nil, fmt.Errorf("%v: if you want to skip this validation, set 'strict_order_by' to false", err)
89+
}
90+
}
91+
}
92+
if n.WindowClause != nil {
93+
for _, item := range n.WindowClause.Items {
94+
sb, ok := item.(*ast.List)
95+
if !ok {
96+
continue
97+
}
98+
for _, single := range sb.Items {
99+
caseExpr, ok := single.(*ast.CaseExpr)
100+
if !ok {
101+
continue
102+
}
103+
if err := findColumnForNode(caseExpr.Xpr, tables, n); err != nil {
104+
return nil, fmt.Errorf("%v: if you want to skip this validation, set 'strict_order_by' to false", err)
105+
}
106+
}
107+
}
108+
}
109+
}
81110

82111
// For UNION queries, targets is empty and we need to look for the
83112
// columns in Largs.
84113
if len(targets.Items) == 0 && n.Larg != nil {
85-
return outputColumns(qc, n.Larg)
114+
return c.outputColumns(qc, n.Larg)
86115
}
87116
case *ast.CallStmt:
88117
targets = &ast.List{}
@@ -303,7 +332,7 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
303332
case ast.EXISTS_SUBLINK:
304333
cols = append(cols, &Column{Name: name, DataType: "bool", NotNull: true})
305334
case ast.EXPR_SUBLINK:
306-
subcols, err := outputColumns(qc, n.Subselect)
335+
subcols, err := c.outputColumns(qc, n.Subselect)
307336
if err != nil {
308337
return nil, err
309338
}
@@ -339,7 +368,7 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
339368
cols = append(cols, col)
340369

341370
case *ast.SelectStmt:
342-
subcols, err := outputColumns(qc, n)
371+
subcols, err := c.outputColumns(qc, n)
343372
if err != nil {
344373
return nil, err
345374
}
@@ -428,7 +457,7 @@ func isTableRequired(n ast.Node, col *Column, prior int) int {
428457
// Return an error if column references don't exist
429458
// Return an error if a table is referenced twice
430459
// Return an error if an unknown column is referenced
431-
func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) {
460+
func (c *Compiler) sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) {
432461
var list *ast.List
433462
switch n := node.(type) {
434463
case *ast.DeleteStmt:
@@ -483,7 +512,7 @@ func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) {
483512
tables = append(tables, table)
484513

485514
case *ast.RangeSubselect:
486-
cols, err := outputColumns(qc, n.Subquery)
515+
cols, err := c.outputColumns(qc, n.Subquery)
487516
if err != nil {
488517
return nil, err
489518
}
@@ -581,6 +610,14 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef)
581610
return cols, nil
582611
}
583612

613+
func findColumnForNode(item ast.Node, tables []*Table, n *ast.SelectStmt) error {
614+
ref, ok := item.(*ast.ColumnRef)
615+
if !ok {
616+
return nil
617+
}
618+
return findColumnForRef(ref, tables, n)
619+
}
620+
584621
func findColumnForRef(ref *ast.ColumnRef, tables []*Table, selectStatement *ast.SelectStmt) error {
585622
parts := stringSlice(ref.Fields)
586623
var alias, name string

Diff for: internal/compiler/parse.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,8 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
8686
} else {
8787
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
8888
}
89-
9089
raw, embeds := rewrite.Embeds(raw)
91-
qc, err := buildQueryCatalog(c.catalog, raw.Stmt, embeds)
90+
qc, err := c.buildQueryCatalog(c.catalog, raw.Stmt, embeds)
9291
if err != nil {
9392
return nil, err
9493
}
@@ -97,7 +96,7 @@ func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query,
9796
if err != nil {
9897
return nil, err
9998
}
100-
cols, err := outputColumns(qc, raw.Stmt)
99+
cols, err := c.outputColumns(qc, raw.Stmt)
101100
if err != nil {
102101
return nil, err
103102
}

Diff for: internal/compiler/query_catalog.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ type QueryCatalog struct {
1414
embeds rewrite.EmbedSet
1515
}
1616

17-
func buildQueryCatalog(c *catalog.Catalog, node ast.Node, embeds rewrite.EmbedSet) (*QueryCatalog, error) {
17+
func (comp *Compiler) buildQueryCatalog(c *catalog.Catalog, node ast.Node, embeds rewrite.EmbedSet) (*QueryCatalog, error) {
1818
var with *ast.WithClause
1919
switch n := node.(type) {
2020
case *ast.DeleteStmt:
@@ -32,7 +32,7 @@ func buildQueryCatalog(c *catalog.Catalog, node ast.Node, embeds rewrite.EmbedSe
3232
if with != nil {
3333
for _, item := range with.Ctes.Items {
3434
if cte, ok := item.(*ast.CommonTableExpr); ok {
35-
cols, err := outputColumns(qc, cte.Ctequery)
35+
cols, err := comp.outputColumns(qc, cte.Ctequery)
3636
if err != nil {
3737
return nil, err
3838
}

Diff for: internal/config/config.go

+1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ type SQL struct {
9999
Schema Paths `json:"schema" yaml:"schema"`
100100
Queries Paths `json:"queries" yaml:"queries"`
101101
StrictFunctionChecks bool `json:"strict_function_checks" yaml:"strict_function_checks"`
102+
StrictOrderBy *bool `json:"strict_order_by" yaml:"strict_order_by"`
102103
Gen SQLGen `json:"gen" yaml:"gen"`
103104
Codegen []Codegen `json:"codegen" yaml:"codegen"`
104105
}

Diff for: internal/config/v_one.go

+6
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ type v1PackageSettings struct {
4646
OutputQuerierFileName string `json:"output_querier_file_name,omitempty" yaml:"output_querier_file_name"`
4747
OutputFilesSuffix string `json:"output_files_suffix,omitempty" yaml:"output_files_suffix"`
4848
StrictFunctionChecks bool `json:"strict_function_checks" yaml:"strict_function_checks"`
49+
StrictOrderBy *bool `json:"strict_order_by" yaml:"strict_order_by"`
4950
QueryParameterLimit *int32 `json:"query_parameter_limit,omitempty" yaml:"query_parameter_limit"`
5051
}
5152

@@ -130,6 +131,10 @@ func (c *V1GenerateSettings) Translate() Config {
130131
}
131132

132133
for _, pkg := range c.Packages {
134+
if pkg.StrictOrderBy == nil {
135+
defaultValue := true
136+
pkg.StrictOrderBy = &defaultValue
137+
}
133138
conf.SQL = append(conf.SQL, SQL{
134139
Engine: pkg.Engine,
135140
Schema: pkg.Schema,
@@ -164,6 +169,7 @@ func (c *V1GenerateSettings) Translate() Config {
164169
},
165170
},
166171
StrictFunctionChecks: pkg.StrictFunctionChecks,
172+
StrictOrderBy: pkg.StrictOrderBy,
167173
})
168174
}
169175

Diff for: internal/config/v_two.go

+4
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ func v2ParseConfig(rd io.Reader) (Config, error) {
110110
return conf, ErrPluginNotFound
111111
}
112112
}
113+
if conf.SQL[j].StrictOrderBy == nil {
114+
defaultValidate := true
115+
conf.SQL[j].StrictOrderBy = &defaultValidate
116+
}
113117
}
114118
return conf, nil
115119
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
-- Example queries for sqlc
2+
CREATE TABLE authors (
3+
id INT
4+
);
5+
6+
-- name: ListAuthors :many
7+
SELECT id FROM authors
8+
ORDER BY adfadsf;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
version: 1
2+
packages:
3+
- path: "go"
4+
name: "querytest"
5+
engine: "postgresql"
6+
schema: "query.sql"
7+
queries: "query.sql"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# package querytest
2+
query.sql:7:1: column reference "adfadsf" not found: if you want to skip this validation, set 'strict_order_by' to false
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
-- Example queries for sqlc
2+
CREATE TABLE authors (
3+
id INT
4+
);
5+
6+
-- name: ListAuthors :many
7+
SELECT id FROM authors
8+
ORDER BY adfadsf;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
version: 1
2+
packages:
3+
- path: "go"
4+
name: "querytest"
5+
engine: "postgresql"
6+
schema: "query.sql"
7+
queries: "query.sql"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# package querytest
2+
query.sql:7:1: column reference "adfadsf" not found: if you want to skip this validation, set 'strict_order_by' to false
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
-- Example queries for sqlc
2+
CREATE TABLE authors (
3+
id INT
4+
);
5+
6+
-- name: ListAuthors :many
7+
SELECT id FROM authors
8+
ORDER BY adfadsf;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
version: 1
2+
packages:
3+
- path: "go"
4+
name: "querytest"
5+
engine: "postgresql"
6+
schema: "query.sql"
7+
queries: "query.sql"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# package querytest
2+
query.sql:7:1: column reference "adfadsf" not found: if you want to skip this validation, set 'strict_order_by' to false

0 commit comments

Comments
 (0)