const std = @import("../std.zig");
const builtin = @import("builtin");
const testing = std.testing;
const http = std.http;
const mem = std.mem;
const net = std.net;
const Uri = std.Uri;
const Allocator = mem.Allocator;
const assert = std.debug.assert;
/// Status-code parsing uses SIMD vectors except on the self-hosted x86_64 backend.
const use_vectors = builtin.zig_backend != .stage2_x86_64;
const Client = @This();
const proto = @import("protocol.zig");
/// When true (from `std.options.http_disable_tls`), TLS support is compiled
/// out: `ca_bundle` and `Connection.tls_client` become `void`.
pub const disable_tls = std.options.http_disable_tls;
/// Allocator used for pool nodes, connection host strings, and TLS client state.
allocator: Allocator,
/// Root certificates handed to TLS handshakes; `void` when TLS is disabled.
ca_bundle: if (disable_tls) void else std.crypto.Certificate.Bundle = if (disable_tls) {} else .{},
/// NOTE(review): not locked anywhere in this chunk; presumably guards
/// `ca_bundle` during certificate rescans elsewhere in the file.
ca_bundle_mutex: std.Thread.Mutex = .{},
/// NOTE(review): presumably requests a CA-bundle rescan before the next HTTPS
/// connection; the reader of this flag is not visible here.
next_https_rescan_certs: bool = true,
/// Idle and in-use connections owned by this client.
connection_pool: ConnectionPool = .{},
/// Proxy for plain-HTTP connections; filled by `initDefaultProxies` when null.
http_proxy: ?*Proxy = null,
/// Proxy for TLS connections; filled by `initDefaultProxies` when null.
https_proxy: ?*Proxy = null,
/// A pool of connections split into an in-use list and a bounded list of
/// idle connections that may be reused.
///
/// Every public method locks `mutex` except `acquireUnsafe`, which expects the
/// caller to already hold it.
pub const ConnectionPool = struct {
    /// Protects all fields below.
    mutex: std.Thread.Mutex = .{},
    /// Connections currently handed out.
    used: Queue = .{},
    /// Idle connections available for reuse. Direct connections are appended
    /// and proxied ones prepended, so lookups (which walk from the back)
    /// prefer direct connections and evictions (from the front) prefer
    /// proxied ones.
    free: Queue = .{},
    /// Number of nodes currently in `free`.
    free_len: usize = 0,
    /// Maximum number of nodes kept in `free`; 0 disables reuse.
    free_size: usize = 32,

    /// What a pooled connection must match to be reused.
    pub const Criteria = struct {
        /// Compared case-insensitively.
        host: []const u8,
        port: u16,
        protocol: Connection.Protocol,
    };

    const Queue = std.DoublyLinkedList(Connection);
    pub const Node = Queue.Node;

    /// Finds an idle connection matching `criteria`, moves it to the in-use
    /// list, and returns it. Returns null when nothing matches.
    pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Connection {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        var next = pool.free.last;
        while (next) |node| : (next = node.prev) {
            if (node.data.protocol != criteria.protocol) continue;
            if (node.data.port != criteria.port) continue;
            // Host names are case-insensitive.
            if (!std.ascii.eqlIgnoreCase(node.data.host, criteria.host)) continue;

            pool.acquireUnsafe(node);
            return &node.data;
        }

        return null;
    }

    /// Moves `node` from the free list to the in-use list.
    /// The caller must already hold `mutex`.
    pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void {
        pool.free.remove(node);
        pool.free_len -= 1;

        pool.used.append(node);
    }

    /// Moves `node` from the free list to the in-use list, taking the lock.
    pub fn acquire(pool: *ConnectionPool, node: *Node) void {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        pool.acquireUnsafe(node);
    }

    /// Returns an in-use `connection` to the pool.
    ///
    /// The connection is closed and freed instead when it is marked `closing`
    /// or when `free_size` is 0. If the free list is already full, the entry
    /// at its front is evicted first to make room.
    pub fn release(pool: *ConnectionPool, allocator: Allocator, connection: *Connection) void {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        const node = @fieldParentPtr(Node, "data", connection);

        pool.used.remove(node);

        if (node.data.closing or pool.free_size == 0) {
            node.data.close(allocator);
            return allocator.destroy(node);
        }

        if (pool.free_len >= pool.free_size) {
            const popped = pool.free.popFirst() orelse unreachable;
            pool.free_len -= 1;

            popped.data.close(allocator);
            allocator.destroy(popped);
        }

        if (node.data.proxied) {
            // Proxied connections go to the eviction end; direct connections
            // are tried first by `findConnection`.
            pool.free.prepend(node);
        } else {
            pool.free.append(node);
        }

        pool.free_len += 1;
    }

    /// Registers a freshly created node as in use.
    pub fn addUsed(pool: *ConnectionPool, node: *Node) void {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        pool.used.append(node);
    }

    /// Sets the maximum number of idle connections, closing and freeing the
    /// oldest idle connections until at most `new_size` remain.
    pub fn resize(pool: *ConnectionPool, allocator: Allocator, new_size: usize) void {
        pool.mutex.lock();
        defer pool.mutex.unlock();

        while (pool.free_len > new_size) {
            const popped = pool.free.popFirst() orelse unreachable;
            pool.free_len -= 1;

            popped.data.close(allocator);
            allocator.destroy(popped);
        }

        pool.free_size = new_size;
    }

    /// Closes and frees every connection, idle or in use.
    ///
    /// The mutex is intentionally left locked: the pool is set to `undefined`
    /// and must not be used afterwards.
    pub fn deinit(pool: *ConnectionPool, allocator: Allocator) void {
        pool.mutex.lock();

        var next = pool.free.first;
        while (next) |node| {
            defer allocator.destroy(node);
            next = node.next;

            node.data.close(allocator);
        }

        next = pool.used.first;
        while (next) |node| {
            defer allocator.destroy(node);
            next = node.next;

            node.data.close(allocator);
        }

        pool.* = undefined;
    }
};
/// One network connection (plain TCP or TLS) with its own read and write
/// buffers, stored inside a `ConnectionPool` node.
pub const Connection = struct {
    stream: net.Stream,
    /// Heap-allocated TLS state; only meaningful when `protocol == .tls`.
    tls_client: if (!disable_tls) *std.crypto.tls.Client else void,
    protocol: Protocol,
    /// Owned copy of the host name; released by `close`.
    host: []u8,
    port: u16,
    /// Set when traffic on this connection goes through a proxy.
    proxied: bool = false,
    /// When set, `ConnectionPool.release` closes the connection instead of
    /// keeping it for reuse.
    closing: bool = false,
    /// `read_buf[read_start..read_end]` holds received bytes not yet consumed.
    read_start: BufferSize = 0,
    read_end: BufferSize = 0,
    /// `write_buf[0..write_end]` holds queued bytes not yet sent.
    write_end: BufferSize = 0,
    read_buf: [buffer_size]u8 = undefined,
    write_buf: [buffer_size]u8 = undefined,

    pub const buffer_size = std.crypto.tls.max_ciphertext_record_len;
    const BufferSize = std.math.IntFittingRange(0, buffer_size);

    pub const Protocol = enum { plain, tls };

    /// Reads through the TLS client into `buffers`, collapsing TLS and
    /// socket errors into `ReadError`.
    pub fn readvDirectTls(conn: *Connection, buffers: []std.os.iovec) ReadError!usize {
        if (conn.tls_client.readv(conn.stream, buffers)) |count| {
            return count;
        } else |err| {
            // The alert errors form a large family; fold them by name
            // (see https://github.com/ziglang/zig/issues/2473).
            if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert;

            return switch (err) {
                error.TlsConnectionTruncated,
                error.TlsRecordOverflow,
                error.TlsDecodeError,
                error.TlsBadRecordMac,
                error.TlsBadLength,
                error.TlsIllegalParameter,
                error.TlsUnexpectedMessage,
                => error.TlsFailure,
                error.ConnectionTimedOut => error.ConnectionTimedOut,
                error.ConnectionResetByPeer, error.BrokenPipe => error.ConnectionResetByPeer,
                else => error.UnexpectedReadFailure,
            };
        }
    }

    /// Reads into `buffers` straight from the transport, bypassing `read_buf`.
    pub fn readvDirect(conn: *Connection, buffers: []std.os.iovec) ReadError!usize {
        switch (conn.protocol) {
            .tls => {
                if (disable_tls) unreachable;
                return conn.readvDirectTls(buffers);
            },
            .plain => {},
        }

        const count = conn.stream.readv(buffers) catch |err| return switch (err) {
            error.ConnectionTimedOut => error.ConnectionTimedOut,
            error.ConnectionResetByPeer, error.BrokenPipe => error.ConnectionResetByPeer,
            else => error.UnexpectedReadFailure,
        };
        return count;
    }

    /// Refills `read_buf` when it has been fully consumed; otherwise does
    /// nothing. Returns `error.EndOfStream` if the peer sent no more data.
    pub fn fill(conn: *Connection) ReadError!void {
        if (conn.read_start != conn.read_end) return;

        var iov = [_]std.os.iovec{
            .{ .iov_base = &conn.read_buf, .iov_len = conn.read_buf.len },
        };
        const got = try conn.readvDirect(&iov);
        if (got == 0) return error.EndOfStream;

        conn.read_end = @intCast(got);
        conn.read_start = 0;
    }

    /// Returns the buffered bytes that have not been consumed yet.
    pub fn peek(conn: *Connection) []const u8 {
        const start = conn.read_start;
        const end = conn.read_end;
        return conn.read_buf[start..end];
    }

    /// Marks `num` buffered bytes as consumed.
    pub fn drop(conn: *Connection, num: BufferSize) void {
        conn.read_start = conn.read_start + num;
    }

    /// Reads into `buffer`, serving buffered bytes first. When nothing is
    /// buffered, reads directly into `buffer` and stashes any overflow in
    /// `read_buf` for later calls.
    pub fn read(conn: *Connection, buffer: []u8) ReadError!usize {
        const buffered = conn.read_buf[conn.read_start..conn.read_end];
        if (buffered.len != 0) {
            const take = @min(buffered.len, buffer.len);
            @memcpy(buffer[0..take], buffered[0..take]);
            conn.read_start += @intCast(take);
            return take;
        }

        var iovecs = [_]std.os.iovec{
            .{ .iov_base = buffer.ptr, .iov_len = buffer.len },
            .{ .iov_base = &conn.read_buf, .iov_len = conn.read_buf.len },
        };
        const nread = try conn.readvDirect(&iovecs);
        if (nread <= buffer.len) return nread;

        // The excess landed in `read_buf`.
        conn.read_start = 0;
        conn.read_end = @intCast(nread - buffer.len);
        return buffer.len;
    }

    pub const ReadError = error{
        TlsFailure,
        TlsAlert,
        ConnectionTimedOut,
        ConnectionResetByPeer,
        UnexpectedReadFailure,
        EndOfStream,
    };

    pub const Reader = std.io.Reader(*Connection, ReadError, read);

    pub fn reader(conn: *Connection) Reader {
        return .{ .context = conn };
    }

    /// Writes all of `buffer` through the TLS client, bypassing `write_buf`.
    pub fn writeAllDirectTls(conn: *Connection, buffer: []const u8) WriteError!void {
        conn.tls_client.writeAll(conn.stream, buffer) catch |err| return switch (err) {
            error.BrokenPipe, error.ConnectionResetByPeer => error.ConnectionResetByPeer,
            else => error.UnexpectedWriteFailure,
        };
    }

    /// Writes all of `buffer` straight to the transport, bypassing `write_buf`.
    pub fn writeAllDirect(conn: *Connection, buffer: []const u8) WriteError!void {
        switch (conn.protocol) {
            .tls => {
                if (disable_tls) unreachable;
                return conn.writeAllDirectTls(buffer);
            },
            .plain => {},
        }

        conn.stream.writeAll(buffer) catch |err| return switch (err) {
            error.BrokenPipe, error.ConnectionResetByPeer => error.ConnectionResetByPeer,
            else => error.UnexpectedWriteFailure,
        };
    }

    /// Queues `buffer` in `write_buf`, flushing first when it does not fit.
    /// Data larger than the whole buffer is written directly.
    pub fn write(conn: *Connection, buffer: []const u8) WriteError!usize {
        const room = conn.write_buf.len - conn.write_end;
        if (buffer.len > room) {
            try conn.flush();

            if (buffer.len > conn.write_buf.len) {
                try conn.writeAllDirect(buffer);
                return buffer.len;
            }
        }

        const start = conn.write_end;
        @memcpy(conn.write_buf[start..][0..buffer.len], buffer);
        conn.write_end += @intCast(buffer.len);
        return buffer.len;
    }

    /// Reserves `len` bytes at the end of `write_buf` (flushing first when
    /// they do not fit) and returns them for the caller to fill in place.
    pub fn allocWriteBuffer(conn: *Connection, len: BufferSize) WriteError![]u8 {
        if (len > conn.write_buf.len - conn.write_end) try conn.flush();

        const start = conn.write_end;
        conn.write_end += len;
        return conn.write_buf[start..][0..len];
    }

    /// Sends any queued bytes and empties `write_buf`.
    pub fn flush(conn: *Connection) WriteError!void {
        if (conn.write_end > 0) {
            try conn.writeAllDirect(conn.write_buf[0..conn.write_end]);
            conn.write_end = 0;
        }
    }

    pub const WriteError = error{
        ConnectionResetByPeer,
        UnexpectedWriteFailure,
    };

    pub const Writer = std.io.Writer(*Connection, WriteError, write);

    pub fn writer(conn: *Connection) Writer {
        return .{ .context = conn };
    }

    /// Shuts the connection down and frees what it owns: the TLS client
    /// (when present) and `host`. Does not free the pool node itself.
    pub fn close(conn: *Connection, allocator: Allocator) void {
        switch (conn.protocol) {
            .tls => {
                if (disable_tls) unreachable;

                // Best-effort clean TLS shutdown; failures are deliberately ignored.
                _ = conn.tls_client.writeEnd(conn.stream, "", true) catch {};
                allocator.destroy(conn.tls_client);
            },
            .plain => {},
        }

        conn.stream.close();
        allocator.free(conn.host);
    }
};
/// How a request body is framed on the wire.
pub const RequestTransfer = union(enum) {
    /// A body of exactly this many bytes, announced with `content-length`.
    content_length: u64,
    /// A body sent with `transfer-encoding: chunked`.
    chunked,
    /// No request body.
    none,
};
/// Decompression state wrapped around a response body reader.
pub const Compression = union(enum) {
    pub const DeflateDecompressor = std.compress.zlib.Decompressor(Request.TransferReader);
    pub const GzipDecompressor = std.compress.gzip.Decompressor(Request.TransferReader);

    /// zlib-framed deflate stream.
    deflate: DeflateDecompressor,
    /// gzip stream.
    gzip: GzipDecompressor,
    /// Body bytes are passed through unchanged.
    none,
};
/// The parsed head of an HTTP response plus the state used to read its body.
///
/// The string fields (`reason`, `location`, `content_type`,
/// `content_disposition`) are slices into the bytes passed to `parse`.
pub const Response = struct {
    version: http.Version,
    status: http.Status,
    /// Reason phrase from the status line, with leading spaces trimmed.
    reason: []const u8,

    /// Value of the `location` header, if present.
    location: ?[]const u8 = null,
    /// Value of the `content-type` header, if present.
    content_type: ?[]const u8 = null,
    /// Value of the `content-disposition` header, if present.
    content_disposition: ?[]const u8 = null,

    /// Starts from the HTTP version (1.0: false, 1.1: true). A `connection`
    /// header then sets it to false for `close` and to true otherwise.
    keep_alive: bool,

    /// Value of the `content-length` header, if present.
    content_length: ?u64 = null,
    /// Framing taken from the `transfer-encoding` header.
    transfer_encoding: http.TransferEncoding = .none,
    /// Compression taken from `transfer-encoding` or `content-encoding`.
    transfer_compression: http.ContentEncoding = .identity,

    parser: proto.HeadersParser,
    /// Decompressor for the body; set up by `Request.wait`.
    compression: Compression = .none,
    /// Forwarded to the parser when reading the body; set while a redirect
    /// response body is being discarded.
    skip: bool = false,

    pub const ParseError = error{
        HttpHeadersInvalid,
        HttpHeaderContinuationsUnsupported,
        HttpTransferEncodingUnsupported,
        HttpConnectionHeaderUnsupported,
        InvalidContentLength,
        CompressionUnsupported,
    };

    /// Parses the status line, plus the headers this struct tracks, from
    /// `bytes`. `bytes` must hold a complete response head ending with an
    /// empty line. Other headers are ignored.
    ///
    /// Returns `error.HttpHeadersInvalid` in these cases:
    /// - the status line is malformed, including a status code that is not
    ///   exactly three ASCII digits;
    /// - a header name is empty;
    /// - duplicate headers conflict;
    /// - the terminating empty line is missing.
    pub fn parse(res: *Response, bytes: []const u8) ParseError!void {
        var it = mem.splitSequence(u8, bytes, "\r\n");

        const first_line = it.next().?;
        // "HTTP/x.y NNN" is the shortest possible status line.
        if (first_line.len < 12) {
            return error.HttpHeadersInvalid;
        }

        const version: http.Version = switch (int64(first_line[0..8])) {
            int64("HTTP/1.0") => .@"HTTP/1.0",
            int64("HTTP/1.1") => .@"HTTP/1.1",
            else => return error.HttpHeadersInvalid,
        };
        if (first_line[8] != ' ') return error.HttpHeadersInvalid;

        // The status code comes from the peer and `parseInt3` assumes three
        // ASCII digits, so validate before converting.
        const status_digits = first_line[9..12];
        for (status_digits) |c| {
            if (!std.ascii.isDigit(c)) return error.HttpHeadersInvalid;
        }
        // A reason phrase, when present, is separated by a space; this also
        // rejects status codes longer than three digits.
        if (first_line.len > 12 and first_line[12] != ' ') return error.HttpHeadersInvalid;

        const status: http.Status = @enumFromInt(parseInt3(status_digits));
        const reason = mem.trimLeft(u8, first_line[12..], " ");

        res.version = version;
        res.status = status;
        res.reason = reason;
        res.keep_alive = switch (version) {
            .@"HTTP/1.0" => false,
            .@"HTTP/1.1" => true,
        };

        while (it.next()) |line| {
            // An empty line ends the head.
            if (line.len == 0) return;
            switch (line[0]) {
                ' ', '\t' => return error.HttpHeaderContinuationsUnsupported,
                else => {},
            }

            var line_it = mem.splitScalar(u8, line, ':');
            const header_name = line_it.next().?;
            const header_value = mem.trim(u8, line_it.rest(), " \t");
            if (header_name.len == 0) return error.HttpHeadersInvalid;

            if (std.ascii.eqlIgnoreCase(header_name, "connection")) {
                res.keep_alive = !std.ascii.eqlIgnoreCase(header_value, "close");
            } else if (std.ascii.eqlIgnoreCase(header_name, "content-type")) {
                res.content_type = header_value;
            } else if (std.ascii.eqlIgnoreCase(header_name, "location")) {
                res.location = header_value;
            } else if (std.ascii.eqlIgnoreCase(header_name, "content-disposition")) {
                res.content_disposition = header_value;
            } else if (std.ascii.eqlIgnoreCase(header_name, "transfer-encoding")) {
                // Codings are listed in application order, so the framing
                // coding comes last: e.g. "deflate, chunked".
                var iter = mem.splitBackwardsScalar(u8, header_value, ',');

                const first = iter.first();
                const trimmed_first = mem.trim(u8, first, " ");

                var next: ?[]const u8 = first;
                if (std.meta.stringToEnum(http.TransferEncoding, trimmed_first)) |transfer| {
                    // A second framing coding is not supported.
                    if (res.transfer_encoding != .none) return error.HttpHeadersInvalid;
                    res.transfer_encoding = transfer;

                    next = iter.next();
                }

                if (next) |second| {
                    const trimmed_second = mem.trim(u8, second, " ");

                    if (std.meta.stringToEnum(http.ContentEncoding, trimmed_second)) |transfer| {
                        // Double compression is not supported.
                        if (res.transfer_compression != .identity) return error.HttpHeadersInvalid;
                        res.transfer_compression = transfer;
                    } else {
                        return error.HttpTransferEncodingUnsupported;
                    }
                }

                if (iter.next()) |_| return error.HttpTransferEncodingUnsupported;
            } else if (std.ascii.eqlIgnoreCase(header_name, "content-length")) {
                const content_length = std.fmt.parseInt(u64, header_value, 10) catch return error.InvalidContentLength;

                // Repeated content-length headers must agree.
                if (res.content_length != null and res.content_length != content_length) return error.HttpHeadersInvalid;

                res.content_length = content_length;
            } else if (std.ascii.eqlIgnoreCase(header_name, "content-encoding")) {
                if (res.transfer_compression != .identity) return error.HttpHeadersInvalid;

                const trimmed = mem.trim(u8, header_value, " ");

                if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
                    res.transfer_compression = ce;
                } else {
                    return error.HttpTransferEncodingUnsupported;
                }
            }
        }
        // The terminating empty line was never seen.
        return error.HttpHeadersInvalid;
    }

    test parse {
        const response_bytes = "HTTP/1.1 200 OK\r\n" ++
            "LOcation:url\r\n" ++
            "content-tYpe: text/plain\r\n" ++
            "content-disposition:attachment; filename=example.txt \r\n" ++
            "content-Length:10\r\n" ++
            "TRansfer-encoding:\tdeflate, chunked \r\n" ++
            "connectioN:\t keep-alive \r\n\r\n";

        var header_buffer: [1024]u8 = undefined;
        var res = Response{
            .status = undefined,
            .reason = undefined,
            .version = undefined,
            .keep_alive = false,
            .parser = proto.HeadersParser.init(&header_buffer),
        };

        @memcpy(header_buffer[0..response_bytes.len], response_bytes);
        res.parser.header_bytes_len = response_bytes.len;

        try res.parse(response_bytes);

        try testing.expectEqual(.@"HTTP/1.1", res.version);
        try testing.expectEqualStrings("OK", res.reason);
        try testing.expectEqual(.ok, res.status);

        try testing.expectEqualStrings("url", res.location.?);
        try testing.expectEqualStrings("text/plain", res.content_type.?);
        try testing.expectEqualStrings("attachment; filename=example.txt", res.content_disposition.?);

        try testing.expectEqual(true, res.keep_alive);
        try testing.expectEqual(10, res.content_length.?);
        try testing.expectEqual(.chunked, res.transfer_encoding);
        try testing.expectEqual(.deflate, res.transfer_compression);
    }

    test "parse rejects malformed status codes" {
        var header_buffer: [64]u8 = undefined;
        var res = Response{
            .status = undefined,
            .reason = undefined,
            .version = undefined,
            .keep_alive = false,
            .parser = proto.HeadersParser.init(&header_buffer),
        };

        try testing.expectError(error.HttpHeadersInvalid, res.parse("HTTP/1.1 2x0 OK\r\n\r\n"));
        try testing.expectError(error.HttpHeadersInvalid, res.parse("HTTP/1.1 -20 OK\r\n\r\n"));
        try testing.expectError(error.HttpHeadersInvalid, res.parse("HTTP/1.1 2000 OK\r\n\r\n"));

        // A status line without a reason phrase is still accepted.
        try res.parse("HTTP/1.1 204\r\n\r\n");
        try testing.expectEqual(.no_content, res.status);
        try testing.expectEqualStrings("", res.reason);
    }

    /// Reinterprets 8 bytes as a `u64` so fixed strings can be matched in a `switch`.
    inline fn int64(array: *const [8]u8) u64 {
        return @bitCast(array.*);
    }

    /// Converts three ASCII digits to their decimal value.
    /// The caller must ensure every byte is '0'...'9' (`parse` checks this).
    fn parseInt3(text: *const [3]u8) u10 {
        if (use_vectors) {
            const nnn: @Vector(3, u8) = text.*;
            const zero: @Vector(3, u8) = .{ '0', '0', '0' };
            const mmm: @Vector(3, u10) = .{ 100, 10, 1 };
            return @reduce(.Add, @as(@Vector(3, u10), nnn -% zero) *% mmm);
        }
        return std.fmt.parseInt(u10, text, 10) catch unreachable;
    }

    test parseInt3 {
        const expectEqual = testing.expectEqual;
        try expectEqual(@as(u10, 0), parseInt3("000"));
        try expectEqual(@as(u10, 418), parseInt3("418"));
        try expectEqual(@as(u10, 999), parseInt3("999"));
    }

    /// Iterates over the raw headers stored in the parser's header buffer.
    pub fn iterateHeaders(r: Response) http.HeaderIterator {
        return http.HeaderIterator.init(r.parser.get());
    }

    test iterateHeaders {
        const response_bytes = "HTTP/1.1 200 OK\r\n" ++
            "LOcation:url\r\n" ++
            "content-tYpe: text/plain\r\n" ++
            "content-disposition:attachment; filename=example.txt \r\n" ++
            "content-Length:10\r\n" ++
            "TRansfer-encoding:\tdeflate, chunked \r\n" ++
            "connectioN:\t keep-alive \r\n\r\n";

        var header_buffer: [1024]u8 = undefined;
        var res = Response{
            .status = undefined,
            .reason = undefined,
            .version = undefined,
            .keep_alive = false,
            .parser = proto.HeadersParser.init(&header_buffer),
        };

        @memcpy(header_buffer[0..response_bytes.len], response_bytes);
        res.parser.header_bytes_len = response_bytes.len;

        var it = res.iterateHeaders();
        {
            const header = it.next().?;
            try testing.expectEqualStrings("LOcation", header.name);
            try testing.expectEqualStrings("url", header.value);
            try testing.expect(!it.is_trailer);
        }
        {
            const header = it.next().?;
            try testing.expectEqualStrings("content-tYpe", header.name);
            try testing.expectEqualStrings("text/plain", header.value);
            try testing.expect(!it.is_trailer);
        }
        {
            const header = it.next().?;
            try testing.expectEqualStrings("content-disposition", header.name);
            try testing.expectEqualStrings("attachment; filename=example.txt", header.value);
            try testing.expect(!it.is_trailer);
        }
        {
            const header = it.next().?;
            try testing.expectEqualStrings("content-Length", header.name);
            try testing.expectEqualStrings("10", header.value);
            try testing.expect(!it.is_trailer);
        }
        {
            const header = it.next().?;
            try testing.expectEqualStrings("TRansfer-encoding", header.name);
            try testing.expectEqualStrings("deflate, chunked", header.value);
            try testing.expect(!it.is_trailer);
        }
        {
            const header = it.next().?;
            try testing.expectEqualStrings("connectioN", header.name);
            try testing.expectEqualStrings("keep-alive", header.value);
            try testing.expect(!it.is_trailer);
        }
        try testing.expectEqual(null, it.next());
    }
};
/// An HTTP request in flight on a pooled `Connection`, together with the
/// `Response` that `wait` reads for it.
pub const Request = struct {
    /// Target of the request; replaced when a redirect is followed.
    uri: Uri,
    client: *Client,
    /// Connection carrying the request. It is released back to the pool in
    /// `deinit` and replaced when a redirect is followed.
    connection: ?*Connection,
    /// Whether to send `connection: keep-alive`. When false, the connection
    /// is marked for closing once the response arrives.
    keep_alive: bool,

    method: http.Method,
    version: http.Version = .@"HTTP/1.1",
    /// Framing of the request body; enforced by `write` and `finish`.
    transfer_encoding: RequestTransfer,
    /// Remaining automatic redirects for `wait`, or a special policy value.
    redirect_behavior: RedirectBehavior,

    /// When true, `wait` skips a `100 Continue` response and reads the next
    /// one; when false, `wait` returns right after the `100 Continue`.
    handle_continue: bool,

    /// Response head and body-reading state, populated by `wait`.
    response: Response,

    /// Standard headers, each of which may be defaulted, omitted, or overridden.
    headers: Headers,

    /// Additional headers written verbatim after the standard ones.
    /// Names must be non-empty. Not freed by this struct.
    extra_headers: []const http.Header,

    /// Headers written like `extra_headers`, but cleared by `wait` before it
    /// follows a redirect to a different scheme, or to a host that is neither
    /// the original host nor a subdomain of it. Not freed by this struct.
    privileged_headers: []const http.Header,

    pub const Headers = struct {
        host: Value = .default,
        authorization: Value = .default,
        user_agent: Value = .default,
        connection: Value = .default,
        accept_encoding: Value = .default,
        content_type: Value = .default,

        pub const Value = union(enum) {
            /// Let `send` emit its default (which may be nothing at all).
            default,
            /// Emit no header.
            omit,
            /// Emit the header with exactly this value.
            override: []const u8,
        };
    };

    /// Redirect policy: a count of remaining redirects, with two special values.
    pub const RedirectBehavior = enum(u16) {
        /// A redirect response makes `wait` fail with `error.TooManyHttpRedirects`.
        not_allowed = 0,
        /// Redirect responses are returned to the caller without being followed.
        unhandled = std.math.maxInt(u16),
        _,

        /// Decrements the remaining count. Asserts the value is neither
        /// `not_allowed` nor `unhandled`.
        pub fn subtractOne(rb: *RedirectBehavior) void {
            switch (rb.*) {
                .not_allowed => unreachable,
                .unhandled => unreachable,
                _ => rb.* = @enumFromInt(@intFromEnum(rb.*) - 1),
            }
        }

        /// Returns the remaining count. Asserts the value is not `unhandled`.
        pub fn remaining(rb: RedirectBehavior) u16 {
            assert(rb != .unhandled);
            return @intFromEnum(rb);
        }
    };

    /// Returns the connection to the pool. Marks it for closing first if the
    /// response body was not fully read.
    pub fn deinit(req: *Request) void {
        if (req.connection) |connection| {
            if (!req.response.parser.done) {
                // Unread body bytes would corrupt the next request on this connection.
                connection.closing = true;
            }
            req.client.connection_pool.release(req.client.allocator, connection);
        }
        req.* = undefined;
    }

    /// Releases the current connection and connects to `uri`. Consumes one
    /// redirect from `redirect_behavior`. Resets `response`, keeping its
    /// parser and header buffer.
    fn redirect(req: *Request, uri: Uri) !void {
        assert(req.response.parser.done);

        req.client.connection_pool.release(req.client.allocator, req.connection.?);
        req.connection = null;

        const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme;

        const port: u16 = uri.port orelse switch (protocol) {
            .plain => 80,
            .tls => 443,
        };

        const host = uri.host orelse return error.UriMissingHost;

        req.uri = uri;
        req.connection = try req.client.connect(host, port, protocol);
        req.redirect_behavior.subtractOne();
        req.response.parser.reset();

        req.response = .{
            .version = undefined,
            .status = undefined,
            .reason = undefined,
            .keep_alive = undefined,
            .parser = req.response.parser,
        };
    }

    pub const SendError = Connection.WriteError || error{ InvalidContentLength, UnsupportedTransferEncoding };

    pub const SendOptions = struct {
        /// Passed to `Uri.writeToStream` as `.raw` when writing the request target.
        raw_uri: bool = false,
    };

    /// Writes the request line and all headers, then flushes the connection.
    /// Any body is sent afterwards with `write`/`writeAll` and ended with `finish`.
    pub fn send(req: *Request, options: SendOptions) SendError!void {
        if (!req.method.requestHasBody() and req.transfer_encoding != .none)
            return error.UnsupportedTransferEncoding;

        const connection = req.connection.?;
        const w = connection.writer();

        try req.method.write(w);
        try w.writeByte(' ');

        if (req.method == .CONNECT) {
            try req.uri.writeToStream(.{ .authority = true }, w);
        } else {
            // Proxies need the absolute form of the target.
            try req.uri.writeToStream(.{
                .scheme = connection.proxied,
                .authentication = connection.proxied,
                .authority = connection.proxied,
                .path = true,
                .query = true,
                .raw = options.raw_uri,
            }, w);
        }
        try w.writeByte(' ');
        try w.writeAll(@tagName(req.version));
        try w.writeAll("\r\n");

        if (try emitOverridableHeader("host: ", req.headers.host, w)) {
            try w.writeAll("host: ");
            try req.uri.writeToStream(.{ .authority = true }, w);
            try w.writeAll("\r\n");
        }

        if (try emitOverridableHeader("authorization: ", req.headers.authorization, w)) {
            if (req.uri.user != null or req.uri.password != null) {
                try w.writeAll("authorization: ");
                // Encode the credentials directly into the write buffer.
                const authorization = try connection.allocWriteBuffer(
                    @intCast(basic_authorization.valueLengthFromUri(req.uri)),
                );
                assert(basic_authorization.value(req.uri, authorization).len == authorization.len);
                try w.writeAll("\r\n");
            }
        }

        if (try emitOverridableHeader("user-agent: ", req.headers.user_agent, w)) {
            try w.writeAll("user-agent: zig/");
            try w.writeAll(builtin.zig_version_string);
            try w.writeAll(" (std.http)\r\n");
        }

        if (try emitOverridableHeader("connection: ", req.headers.connection, w)) {
            if (req.keep_alive) {
                try w.writeAll("connection: keep-alive\r\n");
            } else {
                try w.writeAll("connection: close\r\n");
            }
        }

        if (try emitOverridableHeader("accept-encoding: ", req.headers.accept_encoding, w)) {
            // zstd is not advertised: https://github.com/ziglang/zig/issues/18937
            try w.writeAll("accept-encoding: gzip, deflate\r\n");
        }

        switch (req.transfer_encoding) {
            .chunked => try w.writeAll("transfer-encoding: chunked\r\n"),
            .content_length => |len| try w.print("content-length: {d}\r\n", .{len}),
            .none => {},
        }

        if (try emitOverridableHeader("content-type: ", req.headers.content_type, w)) {
            // By default content-type is omitted; "application/octet-stream"
            // would be redundant.
        }

        for (req.extra_headers) |header| {
            assert(header.name.len != 0);

            try w.writeAll(header.name);
            try w.writeAll(": ");
            try w.writeAll(header.value);
            try w.writeAll("\r\n");
        }

        // Privileged headers are sent like extra headers. `wait` clears them
        // before following a redirect to another domain or scheme.
        for (req.privileged_headers) |header| {
            assert(header.name.len != 0);

            try w.writeAll(header.name);
            try w.writeAll(": ");
            try w.writeAll(header.value);
            try w.writeAll("\r\n");
        }

        if (connection.proxied) proxy: {
            const proxy = switch (connection.protocol) {
                .plain => req.client.http_proxy,
                .tls => req.client.https_proxy,
            } orelse break :proxy;

            const authorization = proxy.authorization orelse break :proxy;
            try w.writeAll("proxy-authorization: ");
            try w.writeAll(authorization);
            try w.writeAll("\r\n");
        }

        try w.writeAll("\r\n");

        try connection.flush();
    }

    /// Handles the `omit` and `override` cases of a header value.
    /// - `.override`: writes the header and returns false.
    /// - `.omit`: returns false without writing anything.
    /// - `.default`: returns true, so the caller writes its default.
    fn emitOverridableHeader(prefix: []const u8, v: Headers.Value, w: anytype) !bool {
        switch (v) {
            .default => return true,
            .omit => return false,
            .override => |x| {
                try w.writeAll(prefix);
                try w.writeAll(x);
                try w.writeAll("\r\n");
                return false;
            },
        }
    }

    const TransferReadError = Connection.ReadError || proto.HeadersParser.ReadError;

    const TransferReader = std.io.Reader(*Request, TransferReadError, transferRead);

    fn transferReader(req: *Request) TransferReader {
        return .{ .context = req };
    }

    /// Reads undecoded body bytes through the headers parser, forwarding
    /// `response.skip`. Returns 0 once the parser reports the body as done.
    fn transferRead(req: *Request, buf: []u8) TransferReadError!usize {
        if (req.response.parser.done) return 0;

        var index: usize = 0;
        while (index == 0) {
            const amt = try req.response.parser.read(req.connection.?, buf[index..], req.response.skip);
            if (amt == 0 and req.response.parser.done) break;
            index += amt;
        }

        return index;
    }

    pub const WaitError = RequestError || SendError || TransferReadError ||
        proto.HeadersParser.CheckCompleteHeadError || Response.ParseError ||
        error{
        TooManyHttpRedirects,
        RedirectRequiresResend,
        HttpRedirectLocationMissing,
        HttpRedirectLocationInvalid,
        CompressionInitializationFailed,
        CompressionUnsupported,
    };

    /// Reads and parses the response head.
    ///
    /// Follows redirects as allowed by `redirect_behavior`, discarding each
    /// redirect body and re-sending the request. Then prepares decompression
    /// for the final body. The body itself is read afterwards with `read`.
    pub fn wait(req: *Request) WaitError!void {
        while (true) { // one iteration per response (continue / redirect)
            const connection = req.connection.?;

            while (true) { // read until the head is complete
                try connection.fill();

                const nchecked = try req.response.parser.checkCompleteHead(connection.peek());
                connection.drop(@intCast(nchecked));

                if (req.response.parser.state.isContent()) break;
            }

            try req.response.parse(req.response.parser.get());

            if (req.response.status == .@"continue") {
                // Done with the interim response; reset for the real one.
                req.response.parser.done = true;
                req.response.parser.reset();

                if (req.handle_continue)
                    continue;

                return; // the caller handles 100-continue
            }

            // After a successful CONNECT the connection no longer speaks HTTP.
            if (req.method == .CONNECT and req.response.status.class() == .success) {
                connection.closing = false;
                req.response.parser.done = true;
                return;
            }

            connection.closing = !req.response.keep_alive or !req.keep_alive;

            // Responses to HEAD, and 1xx/204/304 responses, end at the first
            // empty line after the headers regardless of the headers present.
            if (req.method == .HEAD or req.response.status.class() == .informational or
                req.response.status == .no_content or req.response.status == .not_modified)
            {
                req.response.parser.done = true;
                return;
            }

            switch (req.response.transfer_encoding) {
                .none => {
                    if (req.response.content_length) |cl| {
                        req.response.parser.next_chunk_length = cl;

                        if (cl == 0) req.response.parser.done = true;
                    } else {
                        // No length given: the body runs until the connection closes.
                        req.response.parser.next_chunk_length = std.math.maxInt(u64);
                    }
                },
                .chunked => {
                    req.response.parser.next_chunk_length = 0;
                    req.response.parser.state = .chunk_head_size;
                },
            }

            if (req.response.status.class() == .redirect and req.redirect_behavior != .unhandled) {
                // Discard the redirect body so the connection is left in a known state.
                req.response.skip = true;
                assert(try req.transferRead(&.{}) == 0); // skipping needs no buffer

                if (req.redirect_behavior == .not_allowed) return error.TooManyHttpRedirects;

                const location = req.response.location orelse
                    return error.HttpRedirectLocationMissing;

                // `resolve_inplace` writes the new URI into the start of
                // header_buffer and returns slices into it.
                const header_buffer = req.response.parser.header_bytes_buffer;
                const new_uri = req.uri.resolve_inplace(location, header_buffer) catch
                    return error.HttpRedirectLocationInvalid;

                // Keep the memory backing new_uri and hand the rest of the
                // buffer to the next request. The buffer must therefore hold
                // every redirect location plus the final response head.
                const path_end = new_uri.path.ptr + new_uri.path.len;
                // https://github.com/ziglang/zig/issues/1738
                const path_offset = @intFromPtr(path_end) - @intFromPtr(header_buffer.ptr);
                const end_offset = @max(path_offset, location.len);
                req.response.parser.header_bytes_buffer = header_buffer[end_offset..];

                if (!isSameDomainOrSubdomain(new_uri.host, req.uri.host) or
                    !std.ascii.eqlIgnoreCase(new_uri.scheme, req.uri.scheme))
                {
                    // Leaving the original domain or scheme: strip privileged headers.
                    req.privileged_headers = &.{};
                }

                if (switch (req.response.status) {
                    .see_other => true,
                    .moved_permanently, .found => req.method == .POST,
                    else => false,
                }) {
                    // Redirecting to a GET changes the method and drops the body.
                    req.method = .GET;
                    req.transfer_encoding = .none;
                    req.headers.content_type = .omit;
                }

                if (req.transfer_encoding != .none) {
                    // The body was already sent. The request stays valid, but
                    // the redirect must be handled by the caller.
                    return error.RedirectRequiresResend;
                }

                try req.redirect(new_uri);
                try req.send(.{});
            } else {
                req.response.skip = false;
                if (!req.response.parser.done) {
                    switch (req.response.transfer_compression) {
                        .identity => req.response.compression = .none,
                        .compress, .@"x-compress" => return error.CompressionUnsupported,
                        .deflate => req.response.compression = .{
                            .deflate = std.compress.zlib.decompressor(req.transferReader()),
                        },
                        .gzip, .@"x-gzip" => req.response.compression = .{
                            .gzip = std.compress.gzip.decompressor(req.transferReader()),
                        },
                        // https://github.com/ziglang/zig/issues/18937
                        .zstd => return error.CompressionUnsupported,
                    }
                }

                break;
            }
        }
    }

    /// Reports whether `candidate` equals `base` or is a subdomain of it
    /// (case-insensitive). A missing host on either side never matches, so
    /// privileged headers are stripped in that case.
    fn isSameDomainOrSubdomain(candidate: ?[]const u8, base: ?[]const u8) bool {
        const c = candidate orelse return false;
        const b = base orelse return false;
        if (!std.ascii.endsWithIgnoreCase(c, b)) return false;
        if (c.len == b.len) return true;
        // Require a label boundary so "evilexample.com" does not match "example.com".
        return c[c.len - b.len - 1] == '.';
    }

    test isSameDomainOrSubdomain {
        try testing.expect(isSameDomainOrSubdomain("example.com", "example.com"));
        try testing.expect(isSameDomainOrSubdomain("API.Example.com", "example.com"));
        try testing.expect(!isSameDomainOrSubdomain("evilexample.com", "example.com"));
        try testing.expect(!isSameDomainOrSubdomain("example.org", "example.com"));
        try testing.expect(!isSameDomainOrSubdomain("com", "example.com"));
        try testing.expect(!isSameDomainOrSubdomain(null, "example.com"));
        try testing.expect(!isSameDomainOrSubdomain("example.com", null));
    }

    pub const ReadError = TransferReadError || proto.HeadersParser.CheckCompleteHeadError ||
        error{ DecompressionFailure, InvalidTrailers };

    pub const Reader = std.io.Reader(*Request, ReadError, read);

    pub fn reader(req: *Request) Reader {
        return .{ .context = req };
    }

    /// Reads response body bytes, decompressed when `wait` set up a
    /// decompressor. Returns 0 at the end of the body, after consuming any
    /// trailers.
    pub fn read(req: *Request, buffer: []u8) ReadError!usize {
        const out_index = switch (req.response.compression) {
            .deflate => |*deflate| deflate.read(buffer) catch return error.DecompressionFailure,
            .gzip => |*gzip| gzip.read(buffer) catch return error.DecompressionFailure,
            // https://github.com/ziglang/zig/issues/18937
            else => try req.transferRead(buffer),
        };
        if (out_index > 0) return out_index;

        while (!req.response.parser.state.isContent()) { // read trailing headers
            try req.connection.?.fill();

            const nchecked = try req.response.parser.checkCompleteHead(req.connection.?.peek());
            req.connection.?.drop(@intCast(nchecked));
        }

        return 0;
    }

    /// Reads until `buffer` is full or the body ends; returns the byte count.
    pub fn readAll(req: *Request, buffer: []u8) !usize {
        var index: usize = 0;
        while (index < buffer.len) {
            const amt = try read(req, buffer[index..]);
            if (amt == 0) break;
            index += amt;
        }
        return index;
    }

    pub const WriteError = Connection.WriteError || error{ NotWriteable, MessageTooLong };

    pub const Writer = std.io.Writer(*Request, WriteError, write);

    pub fn writer(req: *Request) Writer {
        return .{ .context = req };
    }

    /// Writes request body bytes using the framing in `transfer_encoding`.
    /// - Chunked writes emit one chunk per non-empty call.
    /// - Content-length writes fail with `error.MessageTooLong` past the
    ///   announced length.
    /// - With no body, `error.NotWriteable` is returned.
    pub fn write(req: *Request, bytes: []const u8) WriteError!usize {
        switch (req.transfer_encoding) {
            .chunked => {
                // An empty chunk would terminate the body, so skip empty writes.
                if (bytes.len > 0) {
                    try req.connection.?.writer().print("{x}\r\n", .{bytes.len});
                    try req.connection.?.writer().writeAll(bytes);
                    try req.connection.?.writer().writeAll("\r\n");
                }

                return bytes.len;
            },
            .content_length => |*len| {
                if (len.* < bytes.len) return error.MessageTooLong;

                const amt = try req.connection.?.write(bytes);
                len.* -= amt;
                return amt;
            },
            .none => return error.NotWriteable,
        }
    }

    /// Writes all of `bytes` as request body.
    pub fn writeAll(req: *Request, bytes: []const u8) WriteError!void {
        var index: usize = 0;
        while (index < bytes.len) {
            index += try write(req, bytes[index..]);
        }
    }

    pub const FinishError = WriteError || error{MessageNotCompleted};

    /// Ends the request body and flushes the connection.
    /// - Chunked bodies get the terminating zero-length chunk.
    /// - Content-length bodies must have been written in full, otherwise
    ///   `error.MessageNotCompleted` is returned.
    pub fn finish(req: *Request) FinishError!void {
        switch (req.transfer_encoding) {
            .chunked => try req.connection.?.writer().writeAll("0\r\n\r\n"),
            .content_length => |len| if (len != 0) return error.MessageNotCompleted,
            .none => {},
        }

        try req.connection.?.flush();
    }
};
/// An HTTP or HTTPS proxy, configured either directly or from the environment.
pub const Proxy = struct {
    /// Protocol the proxy uses.
    protocol: Connection.Protocol,
    host: []const u8,
    /// Complete `proxy-authorization` header value (such as `Basic ...`),
    /// or null when no credentials are sent.
    authorization: ?[]const u8,
    port: u16,
    /// Whether the proxy accepts `CONNECT` tunnels.
    supports_connect: bool,
};
/// Closes every pooled connection and frees the CA bundle.
///
/// Asserts that no connection is still in use, so every `Request` must have
/// been deinitialized first. `http_proxy` and `https_proxy` are not freed.
pub fn deinit(client: *Client) void {
    const pool = &client.connection_pool;
    assert(pool.used.first == null); // requests are still active

    pool.deinit(client.allocator);

    if (!disable_tls) {
        client.ca_bundle.deinit(client.allocator);
    }

    client.* = undefined;
}
/// Fills in `http_proxy` and `https_proxy` from the environment when they are
/// not already set.
///
/// Lowercase variable names are tried before uppercase ones. `all_proxy` /
/// `ALL_PROXY` serve as the fallback for both. Proxy memory is allocated
/// from `arena`. Asserts that no connection is currently in use.
pub fn initDefaultProxies(client: *Client, arena: Allocator) !void {
    client.connection_pool.mutex.lock();
    defer client.connection_pool.mutex.unlock();

    assert(client.connection_pool.used.first == null);

    const targets = [_]struct { slot: *?*Proxy, names: []const []const u8 }{
        .{ .slot = &client.http_proxy, .names = &.{ "http_proxy", "HTTP_PROXY", "all_proxy", "ALL_PROXY" } },
        .{ .slot = &client.https_proxy, .names = &.{ "https_proxy", "HTTPS_PROXY", "all_proxy", "ALL_PROXY" } },
    };

    for (targets) |target| {
        if (target.slot.* == null) {
            target.slot.* = try createProxyFromEnvVar(arena, target.names);
        }
    }
}
/// Builds a `Proxy` from the first variable in `env_var_names` that is set.
///
/// Returns null when no variable is set or the URL scheme is not recognized.
/// Returns `error.HttpProxyMissingHost` when the URL has no host. All memory
/// comes from `arena`.
fn createProxyFromEnvVar(arena: Allocator, env_var_names: []const []const u8) !?*Proxy {
    var found: ?[]u8 = null;
    for (env_var_names) |name| {
        if (std.process.getEnvVarOwned(arena, name)) |env_value| {
            found = env_value;
            break;
        } else |err| switch (err) {
            error.EnvironmentVariableNotFound => {},
            else => |e| return e,
        }
    }
    const content = found orelse return null;

    // Accept a full URL as well as a bare "host[:port]".
    const uri = Uri.parse(content) catch try Uri.parseWithoutScheme(content);

    const protocol: Connection.Protocol = if (uri.scheme.len == 0)
        .plain // no scheme: assume http://
    else if (protocol_map.get(uri.scheme)) |known|
        known
    else
        return null; // unrecognized scheme: ignore this setting

    const host = uri.host orelse return error.HttpProxyMissingHost;

    var authorization: ?[]const u8 = null;
    if (uri.user != null or uri.password != null) {
        const encoded = try arena.alloc(u8, basic_authorization.valueLengthFromUri(uri));
        assert(basic_authorization.value(uri, encoded).len == encoded.len);
        authorization = encoded;
    }

    const default_port: u16 = switch (protocol) {
        .plain => 80,
        .tls => 443,
    };

    const proxy = try arena.create(Proxy);
    proxy.* = .{
        .protocol = protocol,
        .host = host,
        .authorization = authorization,
        .port = uri.port orelse default_port,
        .supports_connect = true,
    };
    return proxy;
}
/// Helpers for building `Basic` authorization header values from the user
/// and password components of a URI.
pub const basic_authorization = struct {
    pub const max_user_len = 255;
    pub const max_password_len = 255;
    /// Longest value that `value` can produce; useful for sizing buffers.
    pub const max_value_len = valueLength(max_user_len, max_password_len);

    const prefix = "Basic ";

    /// Length of `"Basic " ++ base64(user ++ ":" ++ password)` for the given
    /// component lengths.
    pub fn valueLength(user_len: usize, password_len: usize) usize {
        const unencoded_len = user_len + ":".len + password_len;
        return prefix.len + std.base64.standard.Encoder.calcSize(unencoded_len);
    }

    /// `valueLength` for the credentials in `uri`; a missing component counts as empty.
    pub fn valueLengthFromUri(uri: Uri) usize {
        const user_len: usize = if (uri.user) |user| user.len else 0;
        const password_len: usize = if (uri.password) |password| password.len else 0;
        return valueLength(user_len, password_len);
    }

    /// Writes the header value for `uri`'s credentials into `out` and returns
    /// the written prefix of `out`. `out` must hold `valueLengthFromUri(uri)`
    /// bytes. Asserts the user and password fit their maximum lengths.
    pub fn value(uri: Uri, out: []u8) []u8 {
        const user = uri.user orelse "";
        const password = uri.password orelse "";
        assert(user.len <= max_user_len);
        assert(password.len <= max_password_len);

        var scratch: [max_user_len + ":".len + max_password_len]u8 = undefined;
        const credentials = std.fmt.bufPrint(&scratch, "{s}:{s}", .{ user, password }) catch unreachable;

        @memcpy(out[0..prefix.len], prefix);
        const encoded = std.base64.standard.Encoder.encode(out[prefix.len..], credentials);
        return out[0 .. prefix.len + encoded.len];
    }
};
pub const ConnectTcpError = Allocator.Error || error{
    ConnectionRefused,
    NetworkUnreachable,
    ConnectionTimedOut,
    ConnectionResetByPeer,
    TemporaryNameServerFailure,
    NameServerFailure,
    UnknownHostName,
    HostLacksNetworkAddresses,
    UnexpectedConnectFailure,
    TlsInitializationFailed,
};

/// Returns an in-use connection to `host:port`.
///
/// Reuses a matching idle connection from the pool when there is one.
/// Otherwise opens a new TCP connection, performing the TLS handshake when
/// `protocol == .tls`, and registers it with the pool as in use.
pub fn connectTcp(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectTcpError!*Connection {
    const pool = &client.connection_pool;
    const criteria: ConnectionPool.Criteria = .{ .host = host, .port = port, .protocol = protocol };
    if (pool.findConnection(criteria)) |reused| return reused;

    if (disable_tls and protocol == .tls) return error.TlsInitializationFailed;

    const gpa = client.allocator;

    const node = try gpa.create(ConnectionPool.Node);
    errdefer gpa.destroy(node);
    node.* = .{ .data = undefined };

    const stream = net.tcpConnectToHost(gpa, host, port) catch |err| return switch (err) {
        error.ConnectionRefused => error.ConnectionRefused,
        error.NetworkUnreachable => error.NetworkUnreachable,
        error.ConnectionTimedOut => error.ConnectionTimedOut,
        error.ConnectionResetByPeer => error.ConnectionResetByPeer,
        error.TemporaryNameServerFailure => error.TemporaryNameServerFailure,
        error.NameServerFailure => error.NameServerFailure,
        error.UnknownHostName => error.UnknownHostName,
        error.HostLacksNetworkAddresses => error.HostLacksNetworkAddresses,
        else => error.UnexpectedConnectFailure,
    };
    errdefer stream.close();

    const host_copy = try gpa.dupe(u8, host);
    errdefer gpa.free(host_copy);

    node.data = .{
        .stream = stream,
        .tls_client = undefined,
        .protocol = protocol,
        .host = host_copy,
        .port = port,
    };

    if (protocol == .tls) {
        if (disable_tls) unreachable;

        const tls = try gpa.create(std.crypto.tls.Client);
        errdefer gpa.destroy(tls);

        tls.* = std.crypto.tls.Client.init(stream, client.ca_bundle, host) catch return error.TlsInitializationFailed;
        // Truncation is presumably caught at the HTTP layer instead, via the
        // content-length / chunked framing of the body.
        tls.allow_truncation_attacks = true;
        node.data.tls_client = tls;
    }

    pool.addUsed(node);

    return &node.data;
}
/// Errors `connectUnix` can return. These are allocation failure, socket
/// creation and connect errors, and `NameTooLong` (presumably an over-long
/// socket path).
pub const ConnectUnixError = Allocator.Error ||
    error{NameTooLong} ||
    std.os.SocketError ||
    std.os.ConnectError;
/// Returns a plain connection to the Unix domain socket at `path`.
///
/// Pooled connections are keyed by the socket path with port 0, so an idle
/// one for the same path is reused. A new connection is registered with the
/// pool via `addUsed` and owns a copy of `path`.
pub fn connectUnix(client: *Client, path: []const u8) ConnectUnixError!*Connection {
    const criteria: ConnectionPool.Criteria = .{
        .host = path,
        .port = 0,
        .protocol = .plain,
    };
    if (client.connection_pool.findConnection(criteria)) |existing| return existing;

    const node = try client.allocator.create(ConnectionPool.Node);
    errdefer client.allocator.destroy(node);
    node.* = .{ .data = undefined };

    const stream = try std.net.connectUnixSocket(path);
    errdefer stream.close();

    const host_copy = try client.allocator.dupe(u8, path);
    errdefer client.allocator.free(host_copy);
    node.data = .{
        .stream = stream,
        .tls_client = undefined,
        .protocol = .plain,
        .host = host_copy,
        .port = 0,
    };

    client.connection_pool.addUsed(node);
    return &node.data;
}
/// Returns a connection to `tunnel_host:tunnel_port` established through an
/// HTTP CONNECT request sent to `proxy`. A pooled connection for that
/// destination is reused when one is idle.
///
/// On success the proxy connection is relabelled with the tunnel's host and
/// port, so the pool can hand it out for that destination later.
///
/// Failure modes:
/// - Errors connecting to the proxy, or allocation failure, are returned as-is.
/// - Failures of the CONNECT exchange (open/send/wait, or a non-200 status)
///   become `error.TunnelNotSupported`. `proxy.supports_connect` is then left
///   set only if the proxy answered with a 5xx status.
pub fn connectTunnel(
    client: *Client,
    proxy: *Proxy,
    tunnel_host: []const u8,
    tunnel_port: u16,
) !*Connection {
    if (!proxy.supports_connect) return error.TunnelNotSupported;

    if (client.connection_pool.findConnection(.{
        .host = tunnel_host,
        .port = tunnel_port,
        .protocol = proxy.protocol,
    })) |node|
        return node;

    // Set when the proxy answered 5xx: CONNECT may still work on a later try.
    var maybe_valid = false;
    (tunnel: {
        const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol);
        errdefer {
            conn.closing = true;
            client.connection_pool.release(client.allocator, conn);
        }

        const uri: Uri = .{
            .scheme = "http",
            .user = null,
            .password = null,
            .host = tunnel_host,
            .port = tunnel_port,
            .path = "",
            .query = null,
            .fragment = null,
        };

        var buffer: [8096]u8 = undefined;
        var req = client.open(.CONNECT, uri, .{
            .redirect_behavior = .unhandled,
            .connection = conn,
            .server_header_buffer = &buffer,
        }) catch |err| {
            std.log.debug("err {}", .{err});
            break :tunnel err;
        };
        // Presumably releases `req.connection`. It is cleared below once the
        // tunnel is up, so a successful tunnel keeps `conn`.
        defer req.deinit();

        req.send(.{ .raw_uri = true }) catch |err| break :tunnel err;
        req.wait() catch |err| break :tunnel err;

        if (req.response.status.class() == .server_error) {
            maybe_valid = true;
            break :tunnel error.ServerError;
        }
        if (req.response.status != .ok) break :tunnel error.ConnectionRefused;

        // Tunnel established: keep `conn` away from `req.deinit`.
        req.connection = null;

        // Duplicate before freeing. If the allocation fails, the errdefer
        // above releases `conn` with `closing` set, and the connection's close
        // frees `conn.host`. That slice must therefore still be live at that
        // point, or it would be freed twice.
        const new_host = try client.allocator.dupe(u8, tunnel_host);
        client.allocator.free(conn.host);
        conn.host = new_host;
        conn.port = tunnel_port;
        conn.closing = false;
        return conn;
    }) catch {
        proxy.supports_connect = maybe_valid;
        return error.TunnelNotSupported;
    };
}
/// `ConnectTcpError` plus the extra errors named by the connect path.
const ConnectErrorPartial = error{ ConnectionRefused, UnsupportedUrlScheme } || ConnectTcpError;
/// Error set of `connect`: `ConnectErrorPartial` together with `RequestError`.
pub const ConnectError = RequestError || ConnectErrorPartial;
/// Returns a connection suitable for requesting `host:port` over `protocol`.
///
/// The connection is chosen as follows:
/// - No proxy configured for `protocol`, or the target is the proxy itself:
///   connect directly.
/// - Otherwise, try a CONNECT tunnel through the proxy when the proxy may
///   support it.
/// - If tunnelling is not available, fall back to a plain connection to the
///   proxy, marked `proxied`.
pub fn connect(
    client: *Client,
    host: []const u8,
    port: u16,
    protocol: Connection.Protocol,
) ConnectError!*Connection {
    const proxy = switch (protocol) {
        .plain => client.http_proxy,
        .tls => client.https_proxy,
    } orelse return client.connectTcp(host, port, protocol);

    // Requests addressed to the proxy itself go straight to it.
    if (std.ascii.eqlIgnoreCase(proxy.host, host) and
        proxy.port == port and proxy.protocol == protocol)
    {
        return client.connectTcp(host, port, protocol);
    }

    if (proxy.supports_connect) tunnel: {
        return connectTunnel(client, proxy, host, port) catch |err| switch (err) {
            // Fall back to talking to the proxy directly below.
            error.TunnelNotSupported => break :tunnel,
            else => |e| return e,
        };
    }

    const conn = try client.connectTcp(proxy.host, proxy.port, proxy.protocol);
    errdefer {
        conn.closing = true;
        // `release` needs the allocator to free a closing connection. The
        // previous call omitted it and only compiled because no error path
        // follows yet.
        client.connection_pool.release(client.allocator, conn);
    }

    conn.proxied = true;
    return conn;
}
/// Error set of `open`. It covers connection errors, request send and
/// connection write errors, integer parse errors, and the URI / certificate /
/// transfer-encoding errors listed here.
pub const RequestError = error{
    UnsupportedUrlScheme,
    UriMissingHost,
    CertificateBundleLoadFailure,
    UnsupportedTransferEncoding,
} || ConnectTcpError ||
    ConnectErrorPartial ||
    Connection.WriteError ||
    Request.SendError ||
    std.fmt.ParseIntError;
/// Options accepted by `open`.
pub const RequestOptions = struct {
/// Copied into `Request.version`.
version: http.Version = .@"HTTP/1.1",
/// Copied into `Request.handle_continue` (presumably controls handling of
/// `100 Continue` responses; confirm in `Request`).
handle_continue: bool = true,
/// Copied into `Request.keep_alive`.
keep_alive: bool = true,
/// Copied into `Request.redirect_behavior`.
/// NOTE(review): the default `@enumFromInt(3)` is a raw tag value, presumably
/// "follow up to 3 redirects". Confirm against `Request.RedirectBehavior`.
redirect_behavior: Request.RedirectBehavior = @enumFromInt(3),
/// Caller-provided backing storage for the response header parser
/// (`proto.HeadersParser.init`). Required.
server_header_buffer: []u8,
/// When non-null, `open` uses this connection instead of calling `connect`.
connection: ?*Connection = null,
/// Copied into `Request.headers`.
headers: Request.Headers = .{},
/// Additional headers copied into the request. In runtime-safety builds
/// `open` asserts every name is non-empty, contains no ':' and no "\r\n", and
/// every value contains no "\r\n".
extra_headers: []const http.Header = &.{},
/// Like `extra_headers`, but the ':' check on names is not applied.
privileged_headers: []const http.Header = &.{},
};
/// Maps a URI scheme to its connection protocol. `https` and `wss` map to
/// `.tls`; `http` and `ws` map to `.plain`. Keys are matched exactly
/// (case-sensitive).
pub const protocol_map = std.ComptimeStringMap(Connection.Protocol, .{
    .{ "https", .tls },
    .{ "http", .plain },
    .{ "wss", .tls },
    .{ "ws", .plain },
});
/// Prepares a request for `method` on `uri` without sending anything.
///
/// Uses `options.connection` when given; otherwise obtains one via `connect`.
/// For TLS schemes the CA bundle is rescanned first while
/// `next_https_rescan_certs` is set. If TLS support is compiled out
/// (`disable_tls`), TLS schemes fail with `error.TlsInitializationFailed`.
/// Call `deinit` on the returned request when done with it.
pub fn open(
    client: *Client,
    method: http.Method,
    uri: Uri,
    options: RequestOptions,
) RequestError!Request {
    if (std.debug.runtime_safety) {
        // Catch header names/values that would break header framing: empty
        // names, ':' in extra header names, and embedded CRLF.
        for (options.extra_headers) |header| {
            assert(header.name.len != 0);
            assert(std.mem.indexOfScalar(u8, header.name, ':') == null);
            assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null);
            assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null);
        }
        for (options.privileged_headers) |header| {
            assert(header.name.len != 0);
            assert(std.mem.indexOfPosLinear(u8, header.name, 0, "\r\n") == null);
            assert(std.mem.indexOfPosLinear(u8, header.value, 0, "\r\n") == null);
        }
    }

    const protocol = protocol_map.get(uri.scheme) orelse return error.UnsupportedUrlScheme;
    const port: u16 = uri.port orelse switch (protocol) {
        .plain => 80,
        .tls => 443,
    };
    const host = uri.host orelse return error.UriMissingHost;

    if (protocol == .tls and @atomicLoad(bool, &client.next_https_rescan_certs, .Acquire)) {
        // With TLS compiled out, this branch is reached by the first TLS
        // request (the flag starts out `true`). Fail the way `connectTcp` does
        // instead of hitting `unreachable`. Being comptime-known, this also
        // keeps the `ca_bundle` code below out of such builds.
        if (disable_tls) return error.TlsInitializationFailed;
        client.ca_bundle_mutex.lock();
        defer client.ca_bundle_mutex.unlock();
        // Re-check under the lock: another thread may have finished the rescan.
        if (client.next_https_rescan_certs) {
            client.ca_bundle.rescan(client.allocator) catch return error.CertificateBundleLoadFailure;
            @atomicStore(bool, &client.next_https_rescan_certs, false, .Release);
        }
    }

    const conn = options.connection orelse try client.connect(host, port, protocol);

    var req: Request = .{
        .uri = uri,
        .client = client,
        .connection = conn,
        .keep_alive = options.keep_alive,
        .method = method,
        .version = options.version,
        .transfer_encoding = .none,
        .redirect_behavior = options.redirect_behavior,
        .handle_continue = options.handle_continue,
        .response = .{
            .version = undefined,
            .status = undefined,
            .reason = undefined,
            .keep_alive = undefined,
            .parser = proto.HeadersParser.init(options.server_header_buffer),
        },
        .headers = options.headers,
        .extra_headers = options.extra_headers,
        .privileged_headers = options.privileged_headers,
    };
    errdefer req.deinit();

    return req;
}
/// Options accepted by `fetch`.
pub const FetchOptions = struct {
/// Buffer for the response headers. When null, `fetch` uses a 16 KiB buffer
/// on its own stack.
server_header_buffer: ?[]u8 = null,
/// When null, `fetch` uses `@enumFromInt(3)` for requests without a payload
/// and `.unhandled` for requests with one.
redirect_behavior: ?Request.RedirectBehavior = null,
/// Where the response body goes; `.ignore` discards it.
response_storage: ResponseStorage = .ignore,
/// Upper bound on body bytes appended to `response_storage`. For `.dynamic`
/// a null value means 2 MiB. For `.static` a null value means the list's
/// unused capacity.
max_append_size: ?usize = null,
/// Target of the request.
location: Location,
/// When null, `fetch` uses POST if `payload` is set and GET otherwise.
method: ?http.Method = null,
/// Request body, sent with a content-length transfer encoding.
payload: ?[]const u8 = null,
/// Passed to `Request.send` as `raw_uri`.
raw_uri: bool = false,
/// Forwarded to `open`.
keep_alive: bool = true,
/// Forwarded to `open`.
headers: Request.Headers = .{},
/// Forwarded to `open`.
extra_headers: []const http.Header = &.{},
/// Forwarded to `open`.
privileged_headers: []const http.Header = &.{},
/// Either a URL string (parsed with `Uri.parse`) or an already-parsed `Uri`.
pub const Location = union(enum) {
url: []const u8,
uri: Uri,
};
/// Destination for the response body.
pub const ResponseStorage = union(enum) {
/// Body is skipped.
ignore,
/// Body is read into the list's existing unused capacity only. `fetch`
/// does not grow this list.
static: *std.ArrayListUnmanaged(u8),
/// Body is appended with `readAllArrayList`, which may grow the list.
dynamic: *std.ArrayList(u8),
};
};
/// Outcome of `fetch`.
pub const FetchResult = struct {
/// Status of the response `fetch` waited for.
status: http.Status,
};
/// Performs a complete request/response cycle and returns the status.
///
/// Steps:
/// 1. Open a request for `options.location`.
/// 2. Send `options.payload` (if any) with a content-length.
/// 3. Wait for the response.
/// 4. Store or discard the body according to `options.response_storage`.
pub fn fetch(client: *Client, options: FetchOptions) !FetchResult {
    const uri: Uri = switch (options.location) {
        .url => |url| try Uri.parse(url),
        .uri => |parsed| parsed,
    };

    const has_payload = options.payload != null;
    const method: http.Method = options.method orelse if (has_payload) .POST else .GET;
    const redirect_behavior: Request.RedirectBehavior = options.redirect_behavior orelse
        if (has_payload) .unhandled else @enumFromInt(3);

    var header_storage: [16 * 1024]u8 = undefined;
    var req = try open(client, method, uri, .{
        .server_header_buffer = options.server_header_buffer orelse &header_storage,
        .redirect_behavior = redirect_behavior,
        .headers = options.headers,
        .extra_headers = options.extra_headers,
        .privileged_headers = options.privileged_headers,
        .keep_alive = options.keep_alive,
    });
    defer req.deinit();

    if (options.payload) |body| req.transfer_encoding = .{ .content_length = body.len };
    try req.send(.{ .raw_uri = options.raw_uri });
    if (options.payload) |body| try req.writeAll(body);
    try req.finish();
    try req.wait();

    switch (options.response_storage) {
        .ignore => {
            // Mark the body as skipped. A zero-length read then presumably
            // consumes it without handing back any bytes.
            req.response.skip = true;
            assert(try req.transferRead(&.{}) == 0);
        },
        .dynamic => |list| try req.reader().readAllArrayList(
            list,
            options.max_append_size orelse 2 * 1024 * 1024,
        ),
        .static => |list| {
            // Only the already-allocated spare capacity is filled, optionally
            // capped by `max_append_size`.
            const spare = list.unusedCapacitySlice();
            const limit = @min(spare.len, options.max_append_size orelse spare.len);
            list.items.len += try req.reader().readAll(spare[0..limit]);
        },
    }

    return .{ .status = req.response.status };
}
// Take the address of `initDefaultProxies` so it is semantically analyzed
// when tests are compiled (Zig analyzes declarations lazily).
test {
_ = &initDefaultProxies;
}