Files
go_vndb/vndb.go
2026-05-07 19:34:48 +08:00

325 lines
8.9 KiB
Go

package govndb
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
)
// Client talks to the VNDB Kana HTTP API. Construct one with New, which fills
// in the default base URL and HTTP client and then applies any Options.
type Client struct {
	// baseURL is the API root; request paths are resolved relative to it.
	baseURL string
	// httpClient performs every request.
	httpClient *http.Client
	// token, when non-empty, is sent as "Authorization: Token <token>".
	token string
	// headers are added to every outgoing request.
	headers http.Header
}

// Option customizes a Client during construction.
type Option func(*Client)
// New returns a Client for the VNDB Kana API and applies options in order.
//
// With no options the client uses DefaultBaseURL and http.DefaultClient.
// Nil options are skipped. If the options leave the HTTP client nil or the
// base URL empty, those same defaults are restored afterwards.
func New(options ...Option) *Client {
	c := new(Client)
	c.baseURL = DefaultBaseURL
	c.httpClient = http.DefaultClient
	c.headers = http.Header{}

	for _, apply := range options {
		if apply == nil {
			continue
		}
		apply(c)
	}

	// Options may have cleared these; fall back to the defaults again.
	if c.httpClient == nil {
		c.httpClient = http.DefaultClient
	}
	if c.baseURL == "" {
		c.baseURL = DefaultBaseURL
	}
	return c
}
// WithHTTPClient makes the Client send every request through httpClient.
// When applied through New, a nil httpClient falls back to http.DefaultClient.
func WithHTTPClient(httpClient *http.Client) Option {
	apply := func(c *Client) {
		c.httpClient = httpClient
	}
	return apply
}
// WithBaseURL points the Client at a different API root, which is mainly
// useful for sandbox or test servers.
//
// Trailing slashes are removed. When applied through New, an empty URL
// falls back to DefaultBaseURL.
func WithBaseURL(baseURL string) Option {
	normalized := strings.TrimRight(baseURL, "/")
	return func(c *Client) {
		c.baseURL = normalized
	}
}
// WithToken authenticates every request with a VNDB API token, sent as
// "Authorization: Token <token>".
//
// Surrounding whitespace is trimmed. An empty or all-whitespace token
// results in no Authorization header being sent.
func WithToken(token string) Option {
	trimmed := strings.TrimSpace(token)
	return func(c *Client) {
		c.token = trimmed
	}
}
// WithHeader appends a static header that is added to every request.
//
// Values accumulate with http.Header.Add semantics: repeating a key adds
// another value instead of replacing the earlier one. The option also works
// on a Client whose header map has not been initialized. Option is an
// exported func type and may be applied outside New, where writing to a nil
// map would otherwise panic.
func WithHeader(key, value string) Option {
	return func(client *Client) {
		if client.headers == nil {
			client.headers = make(http.Header)
		}
		client.headers.Add(key, value)
	}
}
// Query POSTs request to endpoint and decodes the JSON response into target.
//
// If request.Fields is empty, the endpoint's default field list is used. If
// there is still no field list, Query returns an error without sending a
// request.
func (c *Client) Query(ctx context.Context, endpoint Endpoint, request QueryRequest, target any) error {
	fields := request.Fields
	if fields == "" {
		fields = defaultFieldsForEndpoint(endpoint)
	}
	if fields == "" {
		return fmt.Errorf("vndb: fields are required")
	}
	request.Fields = fields
	return c.doJSON(ctx, http.MethodPost, string(endpoint), request, target)
}
// Query runs a Kana POST query against endpoint and decodes the response
// into a QueryResponse[T]. A nil client is replaced by New().
func Query[T any](ctx context.Context, client *Client, endpoint Endpoint, request QueryRequest) (*QueryResponse[T], error) {
	c := client
	if c == nil {
		c = New()
	}
	out := new(QueryResponse[T])
	if err := c.Query(ctx, endpoint, request, out); err != nil {
		return nil, err
	}
	return out, nil
}
// QueryVNs runs a query against EndpointVN and decodes the results as VN entries.
func (c *Client) QueryVNs(ctx context.Context, request QueryRequest) (*QueryResponse[VN], error) {
	return Query[VN](ctx, c, EndpointVN, request)
}

// QueryReleases runs a query against EndpointRelease and decodes the results as Release entries.
func (c *Client) QueryReleases(ctx context.Context, request QueryRequest) (*QueryResponse[Release], error) {
	return Query[Release](ctx, c, EndpointRelease, request)
}

// QueryProducers runs a query against EndpointProducer and decodes the results as Producer entries.
func (c *Client) QueryProducers(ctx context.Context, request QueryRequest) (*QueryResponse[Producer], error) {
	return Query[Producer](ctx, c, EndpointProducer, request)
}

// QueryCharacters runs a query against EndpointCharacter and decodes the results as Character entries.
func (c *Client) QueryCharacters(ctx context.Context, request QueryRequest) (*QueryResponse[Character], error) {
	return Query[Character](ctx, c, EndpointCharacter, request)
}

// QueryStaff runs a query against EndpointStaff and decodes the results as Staff entries.
func (c *Client) QueryStaff(ctx context.Context, request QueryRequest) (*QueryResponse[Staff], error) {
	return Query[Staff](ctx, c, EndpointStaff, request)
}

// QueryTags runs a query against EndpointTag and decodes the results as Tag entries.
func (c *Client) QueryTags(ctx context.Context, request QueryRequest) (*QueryResponse[Tag], error) {
	return Query[Tag](ctx, c, EndpointTag, request)
}

// QueryTraits runs a query against EndpointTrait and decodes the results as Trait entries.
func (c *Client) QueryTraits(ctx context.Context, request QueryRequest) (*QueryResponse[Trait], error) {
	return Query[Trait](ctx, c, EndpointTrait, request)
}

// QueryQuotes runs a query against EndpointQuote and decodes the results as Quote entries.
func (c *Client) QueryQuotes(ctx context.Context, request QueryRequest) (*QueryResponse[Quote], error) {
	return Query[Quote](ctx, c, EndpointQuote, request)
}

// QueryUList runs a query against EndpointUList and decodes the results as UListEntry entries.
func (c *Client) QueryUList(ctx context.Context, request QueryRequest) (*QueryResponse[UListEntry], error) {
	return Query[UListEntry](ctx, c, EndpointUList, request)
}
// VN runs a query against EndpointVN and decodes the JSON response into
// target. Use QueryVNs for a typed QueryResponse[VN].
func (c *Client) VN(ctx context.Context, request QueryRequest, target any) error {
	return c.Query(ctx, EndpointVN, request, target)
}

// Release runs a query against EndpointRelease and decodes the JSON response
// into target. Use QueryReleases for a typed QueryResponse[Release].
func (c *Client) Release(ctx context.Context, request QueryRequest, target any) error {
	return c.Query(ctx, EndpointRelease, request, target)
}

// Producer runs a query against EndpointProducer and decodes the JSON
// response into target. Use QueryProducers for a typed QueryResponse[Producer].
func (c *Client) Producer(ctx context.Context, request QueryRequest, target any) error {
	return c.Query(ctx, EndpointProducer, request, target)
}

// Character runs a query against EndpointCharacter and decodes the JSON
// response into target. Use QueryCharacters for a typed QueryResponse[Character].
func (c *Client) Character(ctx context.Context, request QueryRequest, target any) error {
	return c.Query(ctx, EndpointCharacter, request, target)
}

// Staff runs a query against EndpointStaff and decodes the JSON response into
// target. Use QueryStaff for a typed QueryResponse[Staff].
func (c *Client) Staff(ctx context.Context, request QueryRequest, target any) error {
	return c.Query(ctx, EndpointStaff, request, target)
}

// Tag runs a query against EndpointTag and decodes the JSON response into
// target. Use QueryTags for a typed QueryResponse[Tag].
func (c *Client) Tag(ctx context.Context, request QueryRequest, target any) error {
	return c.Query(ctx, EndpointTag, request, target)
}

// Trait runs a query against EndpointTrait and decodes the JSON response into
// target. Use QueryTraits for a typed QueryResponse[Trait].
func (c *Client) Trait(ctx context.Context, request QueryRequest, target any) error {
	return c.Query(ctx, EndpointTrait, request, target)
}

// Quote runs a query against EndpointQuote and decodes the JSON response into
// target. Use QueryQuotes for a typed QueryResponse[Quote].
func (c *Client) Quote(ctx context.Context, request QueryRequest, target any) error {
	return c.Query(ctx, EndpointQuote, request, target)
}

// UList runs a query against EndpointUList and decodes the JSON response into
// target. Use QueryUList for a typed QueryResponse[UListEntry].
func (c *Client) UList(ctx context.Context, request QueryRequest, target any) error {
	return c.Query(ctx, EndpointUList, request, target)
}

// Schema issues GET "schema" (relative to the base URL) and decodes the JSON
// response into target.
func (c *Client) Schema(ctx context.Context, target any) error {
	return c.Get(ctx, "schema", nil, target)
}
// Stats retrieves the site-wide statistics served by GET /stats.
func (c *Client) Stats(ctx context.Context) (*Stats, error) {
	out := new(Stats)
	if err := c.Get(ctx, "stats", nil, out); err != nil {
		return nil, err
	}
	return out, nil
}
// User looks up one or more users via GET /user.
//
// Each entry of query.Queries is sent as a separate "q" parameter. When
// query.Fields is empty, defaultUserFields is requested instead.
func (c *Client) User(ctx context.Context, query UserQuery) (UserResult, error) {
	fields := query.Fields
	if fields == "" {
		fields = defaultUserFields
	}

	params := url.Values{}
	for _, q := range query.Queries {
		params.Add("q", q)
	}
	if fields != "" {
		params.Set("fields", fields)
	}

	var result UserResult
	if err := c.Get(ctx, "user", params, &result); err != nil {
		return nil, err
	}
	return result, nil
}
// AuthInfo issues GET /authinfo and returns the decoded response. The request
// carries the token configured with WithToken, if any.
func (c *Client) AuthInfo(ctx context.Context) (*AuthInfo, error) {
	out := new(AuthInfo)
	if err := c.Get(ctx, "authinfo", nil, out); err != nil {
		return nil, err
	}
	return out, nil
}
// UListLabels retrieves a user's list labels via GET /ulist_labels.
//
// query.User is sent only when non-empty. When query.Fields is empty,
// defaultUListLabelFields is requested instead.
func (c *Client) UListLabels(ctx context.Context, query UListLabelsQuery) (*UListLabelsResponse, error) {
	fields := query.Fields
	if fields == "" {
		fields = defaultUListLabelFields
	}

	params := url.Values{}
	if user := query.User; user != "" {
		params.Set("user", user)
	}
	if fields != "" {
		params.Set("fields", fields)
	}

	out := new(UListLabelsResponse)
	if err := c.Get(ctx, "ulist_labels", params, out); err != nil {
		return nil, err
	}
	return out, nil
}
// Get sends a GET request for path, with query appended as its query string,
// and decodes the JSON response into target. Decoding is skipped when target
// is nil.
func (c *Client) Get(ctx context.Context, path string, query url.Values, target any) error {
	withQuery := pathWithQuery(path, query)
	return c.doJSON(ctx, http.MethodGet, withQuery, nil, target)
}
// doJSON performs an HTTP request against path, resolved relative to the
// client's base URL, and decodes a successful JSON response into target.
//
// Request: a non-nil body is JSON-encoded and sent with Content-Type
// application/json. Every request carries Accept: application/json, an
// "Authorization: Token <token>" header when a token is configured, and
// any static headers added with WithHeader.
//
// Response: a status outside the 2xx range is returned as *ErrorResponse
// with the status code and the trimmed response body. Decoding is skipped
// when target is nil, the body is empty, or the status is 204.
//
// Defaults: a nil receiver behaves like New(). A Client not built with New
// (for example a zero-value Client{}) falls back to http.DefaultClient and
// DefaultBaseURL, matching New's defaults, instead of panicking on a nil
// HTTP client.
func (c *Client) doJSON(ctx context.Context, method, path string, body any, target any) error {
	if c == nil {
		c = New()
	}
	// Mirror New's defaults so a literal or zero-value Client is usable
	// rather than dereferencing a nil *http.Client below.
	httpClient := c.httpClient
	if httpClient == nil {
		httpClient = http.DefaultClient
	}
	baseURL := c.baseURL
	if baseURL == "" {
		baseURL = DefaultBaseURL
	}

	requestURL, err := joinURL(baseURL, path)
	if err != nil {
		return err
	}

	var reader io.Reader
	if body != nil {
		payload, err := json.Marshal(body)
		if err != nil {
			return fmt.Errorf("vndb: marshal request: %w", err)
		}
		reader = bytes.NewReader(payload)
	}

	req, err := http.NewRequestWithContext(ctx, method, requestURL, reader)
	if err != nil {
		return fmt.Errorf("vndb: create request: %w", err)
	}
	if body != nil {
		req.Header.Set("Content-Type", "application/json")
	}
	req.Header.Set("Accept", "application/json")
	if c.token != "" {
		req.Header.Set("Authorization", "Token "+c.token)
	}
	// Static headers are appended, not substituted, so they add values
	// alongside the headers set above.
	for key, values := range c.headers {
		for _, value := range values {
			req.Header.Add(key, value)
		}
	}

	resp, err := httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("vndb: perform request: %w", err)
	}
	defer resp.Body.Close()

	data, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("vndb: read response: %w", err)
	}
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return &ErrorResponse{StatusCode: resp.StatusCode, Body: strings.TrimSpace(string(data))}
	}
	if target == nil || len(data) == 0 || resp.StatusCode == http.StatusNoContent {
		return nil
	}
	if err := json.Unmarshal(data, target); err != nil {
		return fmt.Errorf("vndb: decode response: %w", err)
	}
	return nil
}
// pathWithQuery appends the encoded form of query to path.
//
// path is returned unchanged when query encodes to the empty string. This
// covers nil or empty values, and also keys whose value slices are empty;
// previously those left a dangling "?". If path already contains a query
// string, the new parameters are joined with "&" instead of a second "?".
// No separator is added when path already ends in "?" or "&".
func pathWithQuery(path string, query url.Values) string {
	encoded := query.Encode()
	if encoded == "" {
		return path
	}
	switch {
	case !strings.Contains(path, "?"):
		path += "?"
	case !strings.HasSuffix(path, "?") && !strings.HasSuffix(path, "&"):
		path += "&"
	}
	return path + encoded
}
// joinURL resolves path against baseURL, always treating path as relative to
// the base URL's directory.
//
// Trailing slashes on baseURL and leading slashes on path are normalized, so
// "https://host/kana" + "/vn" yields "https://host/kana/vn". Because leading
// slashes are stripped, a protocol-relative "//host/x" also stays under the
// base. A path carrying its own scheme (e.g. "https://other.host/x") is
// rejected. Otherwise it would silently replace the base URL, and the
// client's Authorization token would be sent to that other host.
func joinURL(baseURL, path string) (string, error) {
	base, err := url.Parse(strings.TrimRight(baseURL, "/") + "/")
	if err != nil {
		return "", fmt.Errorf("vndb: invalid base url: %w", err)
	}
	relative, err := url.Parse(strings.TrimLeft(path, "/"))
	if err != nil {
		return "", fmt.Errorf("vndb: invalid path: %w", err)
	}
	if relative.Scheme != "" || relative.Host != "" {
		return "", fmt.Errorf("vndb: invalid path %q: must be relative to the base url", path)
	}
	return base.ResolveReference(relative).String(), nil
}