@@ -14,11 +14,11 @@ import (
14
14
15
15
// OutputColumns determines which columns a statement will output
16
16
func (c * Compiler ) OutputColumns (stmt ast.Node ) ([]* catalog.Column , error ) {
17
- qc , err := buildQueryCatalog (c .catalog , stmt , nil )
17
+ qc , err := c . buildQueryCatalog (c .catalog , stmt , nil )
18
18
if err != nil {
19
19
return nil , err
20
20
}
21
- cols , err := outputColumns (qc , stmt )
21
+ cols , err := c . outputColumns (qc , stmt )
22
22
if err != nil {
23
23
return nil , err
24
24
}
@@ -51,8 +51,8 @@ func hasStarRef(cf *ast.ColumnRef) bool {
51
51
//
52
52
// Return an error if column references are ambiguous
53
53
// Return an error if column references don't exist
54
- func outputColumns (qc * QueryCatalog , node ast.Node ) ([]* Column , error ) {
55
- tables , err := sourceTables (qc , node )
54
+ func ( c * Compiler ) outputColumns (qc * QueryCatalog , node ast.Node ) ([]* Column , error ) {
55
+ tables , err := c . sourceTables (qc , node )
56
56
if err != nil {
57
57
return nil , err
58
58
}
@@ -68,21 +68,50 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
68
68
69
69
if n .GroupClause != nil {
70
70
for _ , item := range n .GroupClause .Items {
71
- ref , ok := item .(* ast.ColumnRef )
72
- if ! ok {
73
- continue
74
- }
75
-
76
- if err := findColumnForRef (ref , tables , n ); err != nil {
71
+ if err := findColumnForNode (item , tables , n ); err != nil {
77
72
return nil , err
78
73
}
79
74
}
80
75
}
76
+ validateOrderBy := true
77
+ if c .conf .StrictOrderBy != nil {
78
+ validateOrderBy = * c .conf .StrictOrderBy
79
+ }
80
+ if validateOrderBy {
81
+ if n .SortClause != nil {
82
+ for _ , item := range n .SortClause .Items {
83
+ sb , ok := item .(* ast.SortBy )
84
+ if ! ok {
85
+ continue
86
+ }
87
+ if err := findColumnForNode (sb .Node , tables , n ); err != nil {
88
+ return nil , fmt .Errorf ("%v: if you want to skip this validation, set 'strict_order_by' to false" , err )
89
+ }
90
+ }
91
+ }
92
+ if n .WindowClause != nil {
93
+ for _ , item := range n .WindowClause .Items {
94
+ sb , ok := item .(* ast.List )
95
+ if ! ok {
96
+ continue
97
+ }
98
+ for _ , single := range sb .Items {
99
+ caseExpr , ok := single .(* ast.CaseExpr )
100
+ if ! ok {
101
+ continue
102
+ }
103
+ if err := findColumnForNode (caseExpr .Xpr , tables , n ); err != nil {
104
+ return nil , fmt .Errorf ("%v: if you want to skip this validation, set 'strict_order_by' to false" , err )
105
+ }
106
+ }
107
+ }
108
+ }
109
+ }
81
110
82
111
// For UNION queries, targets is empty and we need to look for the
83
112
// columns in Largs.
84
113
if len (targets .Items ) == 0 && n .Larg != nil {
85
- return outputColumns (qc , n .Larg )
114
+ return c . outputColumns (qc , n .Larg )
86
115
}
87
116
case * ast.CallStmt :
88
117
targets = & ast.List {}
@@ -303,7 +332,7 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
303
332
case ast .EXISTS_SUBLINK :
304
333
cols = append (cols , & Column {Name : name , DataType : "bool" , NotNull : true })
305
334
case ast .EXPR_SUBLINK :
306
- subcols , err := outputColumns (qc , n .Subselect )
335
+ subcols , err := c . outputColumns (qc , n .Subselect )
307
336
if err != nil {
308
337
return nil , err
309
338
}
@@ -339,7 +368,7 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
339
368
cols = append (cols , col )
340
369
341
370
case * ast.SelectStmt :
342
- subcols , err := outputColumns (qc , n )
371
+ subcols , err := c . outputColumns (qc , n )
343
372
if err != nil {
344
373
return nil , err
345
374
}
@@ -428,7 +457,7 @@ func isTableRequired(n ast.Node, col *Column, prior int) int {
428
457
// Return an error if column references don't exist
429
458
// Return an error if a table is referenced twice
430
459
// Return an error if an unknown column is referenced
431
- func sourceTables (qc * QueryCatalog , node ast.Node ) ([]* Table , error ) {
460
+ func ( c * Compiler ) sourceTables (qc * QueryCatalog , node ast.Node ) ([]* Table , error ) {
432
461
var list * ast.List
433
462
switch n := node .(type ) {
434
463
case * ast.DeleteStmt :
@@ -483,7 +512,7 @@ func sourceTables(qc *QueryCatalog, node ast.Node) ([]*Table, error) {
483
512
tables = append (tables , table )
484
513
485
514
case * ast.RangeSubselect :
486
- cols , err := outputColumns (qc , n .Subquery )
515
+ cols , err := c . outputColumns (qc , n .Subquery )
487
516
if err != nil {
488
517
return nil , err
489
518
}
@@ -581,6 +610,14 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef)
581
610
return cols , nil
582
611
}
583
612
613
+ func findColumnForNode (item ast.Node , tables []* Table , n * ast.SelectStmt ) error {
614
+ ref , ok := item .(* ast.ColumnRef )
615
+ if ! ok {
616
+ return nil
617
+ }
618
+ return findColumnForRef (ref , tables , n )
619
+ }
620
+
584
621
func findColumnForRef (ref * ast.ColumnRef , tables []* Table , selectStatement * ast.SelectStmt ) error {
585
622
parts := stringSlice (ref .Fields )
586
623
var alias , name string
0 commit comments