Skip to content

Commit cbe4419

Browse files
authored
feat(civil): implement database/sql.Scanner|Valuer (#1145) (#11808)
Fixes #1145. I wasn't sure how to test these in the context of this repo. Questions: - I tried to be comprehensive in the Scan functions, I'm not sure exactly what types of values to expect - For the Valuer, I assumed that most/all databases should accept the ISO string values, but that may not be globally true
1 parent 31cd272 commit cbe4419

File tree

2 files changed

+225
-0
lines changed

2 files changed

+225
-0
lines changed

civil/civil.go

+121
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
package civil
2323

2424
import (
25+
"database/sql/driver"
2526
"fmt"
2627
"time"
2728
)
@@ -158,6 +159,46 @@ func (d *Date) UnmarshalText(data []byte) error {
158159
return err
159160
}
160161

162+
// Value implements the database/sql/driver Valuer interface.
163+
func (d Date) Value() (driver.Value, error) {
164+
return d.String(), nil
165+
}
166+
167+
// Scan implements the database/sql Scanner interface.
168+
func (d *Date) Scan(v any) error {
169+
switch vt := v.(type) {
170+
case time.Time:
171+
*d = DateOf(vt)
172+
case *time.Time:
173+
if vt != nil {
174+
*d = DateOf(*vt)
175+
}
176+
case string:
177+
var err error
178+
*d, err = ParseDate(vt)
179+
return err
180+
case *string:
181+
var err error
182+
if vt != nil {
183+
*d, err = ParseDate(*vt)
184+
}
185+
return err
186+
case []byte:
187+
var err error
188+
*d, err = ParseDate(string(vt))
189+
return err
190+
case *[]byte:
191+
var err error
192+
if vt != nil {
193+
*d, err = ParseDate(string(*vt))
194+
}
195+
return err
196+
default:
197+
return fmt.Errorf("unsupported scan type for Date: %T", v)
198+
}
199+
return nil
200+
}
201+
161202
// A Time represents a time with nanosecond precision.
162203
//
163204
// This type does not include location information, and therefore does not
@@ -262,6 +303,46 @@ func (t *Time) UnmarshalText(data []byte) error {
262303
return err
263304
}
264305

306+
// Value implements the database/sql/driver Valuer interface.
307+
func (t Time) Value() (driver.Value, error) {
308+
return t.String(), nil
309+
}
310+
311+
// Scan implements the database/sql Scanner interface.
312+
func (t *Time) Scan(v any) error {
313+
switch vt := v.(type) {
314+
case time.Time:
315+
*t = TimeOf(vt)
316+
case *time.Time:
317+
if vt != nil {
318+
*t = TimeOf(*vt)
319+
}
320+
case string:
321+
var err error
322+
*t, err = ParseTime(vt)
323+
return err
324+
case *string:
325+
var err error
326+
if vt != nil {
327+
*t, err = ParseTime(*vt)
328+
}
329+
return err
330+
case []byte:
331+
var err error
332+
*t, err = ParseTime(string(vt))
333+
return err
334+
case *[]byte:
335+
var err error
336+
if vt != nil {
337+
*t, err = ParseTime(string(*vt))
338+
}
339+
return err
340+
default:
341+
return fmt.Errorf("unsupported scan type for Time: %T", v)
342+
}
343+
return nil
344+
}
345+
265346
// A DateTime represents a date and time.
266347
//
267348
// This type does not include location information, and therefore does not
@@ -365,3 +446,43 @@ func (dt *DateTime) UnmarshalText(data []byte) error {
365446
*dt, err = ParseDateTime(string(data))
366447
return err
367448
}
449+
450+
// Value implements the database/sql/driver Valuer interface.
451+
func (dt DateTime) Value() (driver.Value, error) {
452+
return dt.String(), nil
453+
}
454+
455+
// Scan implements the database/sql Scanner interface.
456+
func (dt *DateTime) Scan(v any) error {
457+
switch vt := v.(type) {
458+
case time.Time:
459+
*dt = DateTimeOf(vt)
460+
case *time.Time:
461+
if vt != nil {
462+
*dt = DateTimeOf(*vt)
463+
}
464+
case string:
465+
var err error
466+
*dt, err = ParseDateTime(vt)
467+
return err
468+
case *string:
469+
var err error
470+
if vt != nil {
471+
*dt, err = ParseDateTime(*vt)
472+
}
473+
return err
474+
case []byte:
475+
var err error
476+
*dt, err = ParseDateTime(string(vt))
477+
return err
478+
case *[]byte:
479+
var err error
480+
if vt != nil {
481+
*dt, err = ParseDateTime(string(*vt))
482+
}
483+
return err
484+
default:
485+
return fmt.Errorf("unsupported scan type for DateTime: %T", v)
486+
}
487+
return nil
488+
}

civil/civil_test.go

+104
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
package civil
1616

1717
import (
18+
"database/sql"
19+
"database/sql/driver"
1820
"encoding/json"
1921
"testing"
2022
"time"
@@ -705,3 +707,105 @@ func TestUnmarshalJSON(t *testing.T) {
705707
}
706708
}
707709
}
710+
711+
func TestValuer(t *testing.T) {
712+
for _, test := range []struct {
713+
data driver.Valuer
714+
want interface{}
715+
}{
716+
{&Date{1987, 4, 15}, `1987-04-15`},
717+
{&Time{18, 54, 2, 0}, `18:54:02`},
718+
{&DateTime{Date{1987, 4, 15}, Time{18, 54, 2, 0}}, `1987-04-15T18:54:02`},
719+
} {
720+
got, err := test.data.Value()
721+
if err != nil {
722+
t.Fatalf("%s: %v", test.data, err)
723+
}
724+
if !cmp.Equal(got, test.want) {
725+
t.Errorf("%s: got %#v, want %#v", test.data, test.data, test.want)
726+
}
727+
}
728+
}
729+
730+
func TestScanner(t *testing.T) {
731+
var d Date
732+
var tm Time
733+
var dt DateTime
734+
for _, test := range []struct {
735+
data interface{}
736+
ptr sql.Scanner
737+
want interface{}
738+
}{
739+
// time input
740+
{time.Date(1987, 4, 15, 18, 54, 2, 0, time.UTC), &d, &Date{1987, 4, 15}},
741+
{time.Date(1987, 4, 15, 18, 54, 2, 0, time.UTC), &tm, &Time{18, 54, 2, 0}},
742+
{time.Date(1987, 4, 15, 18, 54, 2, 0, time.UTC), &dt, &DateTime{Date{1987, 4, 15}, Time{18, 54, 2, 0}}},
743+
744+
// *time input
745+
{toPtr(time.Date(1987, 4, 15, 18, 54, 2, 0, time.UTC)), &d, &Date{1987, 4, 15}},
746+
{toPtr(time.Date(1987, 4, 15, 18, 54, 2, 0, time.UTC)), &tm, &Time{18, 54, 2, 0}},
747+
{toPtr(time.Date(1987, 4, 15, 18, 54, 2, 0, time.UTC)), &dt, &DateTime{Date{1987, 4, 15}, Time{18, 54, 2, 0}}},
748+
749+
// string input
750+
{`1987-04-15`, &d, &Date{1987, 4, 15}},
751+
{`18:54:02`, &tm, &Time{18, 54, 2, 0}},
752+
{`1987-04-15T18:54:02`, &dt, &DateTime{Date{1987, 4, 15}, Time{18, 54, 2, 0}}},
753+
754+
// *string input
755+
{toPtr(`1987-04-15`), &d, &Date{1987, 4, 15}},
756+
{toPtr(`18:54:02`), &tm, &Time{18, 54, 2, 0}},
757+
{toPtr(`1987-04-15T18:54:02`), &dt, &DateTime{Date{1987, 4, 15}, Time{18, 54, 2, 0}}},
758+
759+
// []byte input
760+
{[]byte(`1987-04-15`), &d, &Date{1987, 4, 15}},
761+
{[]byte(`18:54:02`), &tm, &Time{18, 54, 2, 0}},
762+
{[]byte(`1987-04-15T18:54:02`), &dt, &DateTime{Date{1987, 4, 15}, Time{18, 54, 2, 0}}},
763+
764+
// *[]byte input
765+
{toPtr([]byte(`1987-04-15`)), &d, &Date{1987, 4, 15}},
766+
{toPtr([]byte(`18:54:02`)), &tm, &Time{18, 54, 2, 0}},
767+
{toPtr([]byte(`1987-04-15T18:54:02`)), &dt, &DateTime{Date{1987, 4, 15}, Time{18, 54, 2, 0}}},
768+
} {
769+
if err := test.ptr.Scan(test.data); err != nil {
770+
t.Fatalf("%s: %v", test.data, err)
771+
}
772+
if !cmp.Equal(test.ptr, test.want) {
773+
t.Errorf("%s: got %#v, want %#v", test.data, test.ptr, test.want)
774+
}
775+
}
776+
777+
// expected test failures
778+
for _, test := range []struct {
779+
data interface{}
780+
ptr sql.Scanner
781+
want string
782+
}{
783+
// int64 input
784+
{int64(12345), &d, "unsupported scan type for Date: int64"},
785+
{int64(12345), &tm, "unsupported scan type for Time: int64"},
786+
{int64(12345), &dt, "unsupported scan type for DateTime: int64"},
787+
788+
// float64 input
789+
{float64(0.9876), &d, "unsupported scan type for Date: float64"},
790+
{float64(0.9876), &tm, "unsupported scan type for Time: float64"},
791+
{float64(0.9876), &dt, "unsupported scan type for DateTime: float64"},
792+
793+
// bool input
794+
{true, &d, "unsupported scan type for Date: bool"},
795+
{true, &tm, "unsupported scan type for Time: bool"},
796+
{true, &dt, "unsupported scan type for DateTime: bool"},
797+
} {
798+
err := test.ptr.Scan(test.data)
799+
if err == nil {
800+
t.Errorf("%q, got nil, want error", test.data)
801+
continue
802+
}
803+
if err.Error() != test.want {
804+
t.Errorf("%v: got %s, want %s", test.data, err, test.want)
805+
}
806+
}
807+
}
808+
809+
func toPtr[V any](v V) *V {
810+
return &v
811+
}

0 commit comments

Comments
 (0)