diff --git a/apps/documentation/server/a2a-server.ts b/apps/documentation/server/a2a-server.ts index 7086c87b1..82ecd430f 100644 --- a/apps/documentation/server/a2a-server.ts +++ b/apps/documentation/server/a2a-server.ts @@ -1,4 +1,3 @@ -import { cors } from '@elysiajs/cors' import { CORE_PORTS, getLocalhostHost, @@ -144,6 +143,25 @@ const AGENT_CARD = { ], } as const +function resolveOrigin(origin?: string | null): string | null { + if (!origin) return null + return ALLOWED_ORIGINS.includes(origin) ? origin : null +} + +function applyCorsHeaders( + set: { headers: Record }, + request: Request, +) { + const allowedOrigin = resolveOrigin(request.headers.get('origin')) + if (!allowedOrigin) return + set.headers['Access-Control-Allow-Origin'] = allowedOrigin + set.headers['Access-Control-Allow-Credentials'] = 'true' + set.headers['Access-Control-Allow-Headers'] = 'Content-Type' + set.headers['Access-Control-Allow-Methods'] = 'GET,POST,OPTIONS' + set.headers['Access-Control-Max-Age'] = '86400' + set.headers.Vary = 'Origin' +} + /** Validate documentation page path (no traversal allowed) */ function validateDocPath(pagePath: string): string { // Normalize and check for path traversal @@ -204,16 +222,14 @@ async function executeSkill( } export const app = new Elysia() - .use( - cors({ - origin: (request) => { - const origin = request.headers.get('origin') - if (!origin) return true - return ALLOWED_ORIGINS.includes(origin) - }, - credentials: true, - }), - ) + .onRequest(({ request, set }) => { + applyCorsHeaders(set, request) + return + }) + .options('*', ({ request, set }) => { + applyCorsHeaders(set, request) + return new Response(null, { status: 204 }) + }) .derive(({ request, server }) => { const forwarded = request.headers.get('x-forwarded-for') const clientIp = diff --git a/apps/documentation/tests/unit/a2a.test.ts b/apps/documentation/tests/unit/a2a.test.ts index f148eeb90..e14a31cd7 100644 --- a/apps/documentation/tests/unit/a2a.test.ts +++ b/apps/documentation/tests/unit/a2a.test.ts @@ -37,8 +37,8 @@ describe('A2A Server Structure', () => { test('has proper CORS configuration', async () => { const serverCode = await Bun.file(SERVER_PATH).text() - expect(serverCode).toContain('cors(') expect(serverCode).toContain('ALLOWED_ORIGINS') + expect(serverCode).toContain('Access-Control-Allow-Origin') }) }) diff --git a/apps/dws/api/rlaif/trainers/grpo.py b/apps/dws/api/rlaif/trainers/grpo.py index c5ecce970..2a760295f 100644 --- a/apps/dws/api/rlaif/trainers/grpo.py +++ b/apps/dws/api/rlaif/trainers/grpo.py @@ -86,32 +86,34 @@ def setup(self, reference_model_cid: str | None = None): self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) - self.model = AutoModelForCausalLM.from_pretrained( + model = cast(PreTrainedModel, AutoModelForCausalLM.from_pretrained( self.model_name, torch_dtype=torch.bfloat16, trust_remote_code=True - ) + )) device = torch.device(self.device) - self.model = cast(PreTrainedModel, torch.nn.Module.to(self.model, device)) - self.model.gradient_checkpointing_enable() - self.model.train() + model = cast(PreTrainedModel, torch.nn.Module.to(model, device)) + model.gradient_checkpointing_enable() + model.train() + self.model = model # Load or clone reference model if reference_model_cid: ref_path = self._download_model(reference_model_cid) - self.ref_model = AutoModelForCausalLM.from_pretrained( + ref_model = cast(PreTrainedModel, AutoModelForCausalLM.from_pretrained( ref_path, torch_dtype=torch.bfloat16, trust_remote_code=True - ) + )) else: # Clone current model as reference - self.ref_model = AutoModelForCausalLM.from_pretrained( + ref_model = cast(PreTrainedModel, AutoModelForCausalLM.from_pretrained( self.model_name, torch_dtype=torch.bfloat16, trust_remote_code=True - ) + )) - self.ref_model = cast(PreTrainedModel, torch.nn.Module.to(self.ref_model, device)) - self.ref_model.eval() - for param in self.ref_model.parameters(): + ref_model = cast(PreTrainedModel, torch.nn.Module.to(ref_model, device)) + ref_model.eval() + for param in ref_model.parameters(): param.requires_grad = False + self.ref_model = ref_model - self.optimizer = AdamW(self.model.parameters(), lr=self.learning_rate) + self.optimizer = AdamW(model.parameters(), lr=self.learning_rate) logger.info(f"Model loaded on {self.device}") diff --git a/apps/oauth3/api/routes/auth-init.ts b/apps/oauth3/api/routes/auth-init.ts index 4a1ef9f71..2c636c56b 100644 --- a/apps/oauth3/api/routes/auth-init.ts +++ b/apps/oauth3/api/routes/auth-init.ts @@ -719,7 +719,6 @@ export function createAuthInitRouter(config: AuthConfig) { return { error: 'invalid_origin' } } - const appId = body.appId ?? 'jeju-default' const rpId = origin.hostname const existingCredentials = await passkeyState.listCredentialsByRpId( diff --git a/apps/oauth3/api/services/kms.ts b/apps/oauth3/api/services/kms.ts index f9ecf9b48..1c763f148 100644 --- a/apps/oauth3/api/services/kms.ts +++ b/apps/oauth3/api/services/kms.ts @@ -102,8 +102,9 @@ export async function generateSecureToken( } const config = getConfig() - const header = { alg: 'ES256K', typ: 'JWT', kid: config.jwtSigningKeyId } - const headerB64 = base64urlEncode(JSON.stringify(header)) + const headerB64 = base64urlEncode( + JSON.stringify({ alg: 'ES256K', typ: 'JWT', kid: config.jwtSigningKeyId }), + ) const payloadB64 = base64urlEncode(JSON.stringify(claims)) const signingInput = `${headerB64}.${payloadB64}` const messageHash = keccak256(toBytes(signingInput)) @@ -132,10 +133,12 @@ export async function verifySecureToken(token: string): Promise { const [headerB64, payloadB64, signatureB64] = parts - let header: { alg?: string; kid?: string } let claims: { sub?: string; iss?: string; exp?: number } try { - header = JSON.parse(base64urlDecode(headerB64)) + const header = JSON.parse(base64urlDecode(headerB64)) + if (header?.alg && header.alg !== 'ES256K') { + return null + } claims = JSON.parse(base64urlDecode(payloadB64)) } catch { return null diff --git a/apps/oauth3/api/services/sealed-oauth.ts b/apps/oauth3/api/services/sealed-oauth.ts index a72f31c6f..fb353d451 100644 --- a/apps/oauth3/api/services/sealed-oauth.ts +++ b/apps/oauth3/api/services/sealed-oauth.ts @@ -1,4 +1,8 @@ -import type { AuthProvider, SealedOAuthProvider } from '../../lib/types' +import type { + AuthProvider, + SealedOAuthProvider, + SealedSecret, +} from '../../lib/types' import { z } from 'zod' import { sealSecret, unsealSecret } from './kms' diff --git a/apps/wallet/web/components/auth/LinkedAccounts.tsx b/apps/wallet/web/components/auth/LinkedAccounts.tsx index d62b60523..fdb932eaf 100644 --- a/apps/wallet/web/components/auth/LinkedAccounts.tsx +++ b/apps/wallet/web/components/auth/LinkedAccounts.tsx @@ -33,7 +33,9 @@ type ProviderInfo = { description: string } -const PROVIDER_INFO: Record = { +type SupportedAuthProvider = Exclude + +const PROVIDER_INFO: Record = { [AuthProvider.WALLET]: { name: 'Wallet', icon: Wallet, @@ -85,13 +87,13 @@ const PROVIDER_INFO: Record = { } type LinkedProvider = { - provider: AuthProvider + provider: SupportedAuthProvider providerId: string handle?: string linkedAt: number } -const SUPPORTED_PROVIDERS: AuthProvider[] = [ +const SUPPORTED_PROVIDERS: SupportedAuthProvider[] = [ AuthProvider.WALLET, AuthProvider.GOOGLE, AuthProvider.APPLE, @@ -102,7 +104,7 @@ const SUPPORTED_PROVIDERS: AuthProvider[] = [ AuthProvider.PASSKEY, ] -function toAuthProvider(type: string): AuthProvider | null { +function toAuthProvider(type: string): SupportedAuthProvider | null { switch (type) { case 'wallet': return AuthProvider.WALLET diff --git a/package.json b/package.json index 5ae737830..a27f1b38d 100644 --- a/package.json +++ b/package.json @@ -82,7 +82,7 @@ "dev": "bun run jeju dev", "build": "bun run jeju build", "test": "bun run jeju test", - "typecheck": "turbo run typecheck --filter='!./vendor/*' && tsc --noEmit && pyright packages/training/python apps/dws/api && bun run \"typecheck:rust\"", + "typecheck": "turbo run typecheck && tsc --noEmit && pyright packages/training/python apps/dws/api && bun run \"typecheck:rust\"", "typecheck:rust": "bash -c '. $HOME/.cargo/env 2>/dev/null; mkdir -p apps/node/app/dist apps/vpn/app/dist apps/wallet/app/dist; cargo check --manifest-path apps/node/app/src-tauri/Cargo.toml && cargo check --manifest-path apps/vpn/app/src-tauri/Cargo.toml && cargo check --manifest-path apps/wallet/app/src-tauri/Cargo.toml'", "lint": "biome check --write --unsafe packages apps && ruff check packages apps --exclude packages/workerd --exclude packages/contracts/lib --fix && ruff format packages apps --exclude packages/workerd --exclude packages/contracts/lib && bun run \"lint:rust\"", "lint:rust": "bash -c '. $HOME/.cargo/env 2>/dev/null; cargo fmt --manifest-path apps/node/app/src-tauri/Cargo.toml && cargo fmt --manifest-path apps/vpn/app/src-tauri/Cargo.toml && cargo fmt --manifest-path apps/wallet/app/src-tauri/Cargo.toml'", diff --git a/packages/auth/src/credentials/verifiable-credentials.ts b/packages/auth/src/credentials/verifiable-credentials.ts index d11c05e71..755b0d679 100644 --- a/packages/auth/src/credentials/verifiable-credentials.ts +++ b/packages/auth/src/credentials/verifiable-credentials.ts @@ -309,19 +309,28 @@ export class VerifiableCredentialIssuer { } private getCredentialTypeForProvider(provider: AuthProvider): string { - const typeMap: Record = { - wallet: 'WalletOwnershipCredential', - farcaster: 'FarcasterAccountCredential', - google: 'GoogleAccountCredential', - apple: 'AppleAccountCredential', - twitter: 'TwitterAccountCredential', - github: 'GitHubAccountCredential', - discord: 'DiscordAccountCredential', - email: 'EmailAccountCredential', - phone: 'PhoneAccountCredential', - } - - return typeMap[provider] ?? 'OAuth3IdentityCredential' + switch (provider) { + case 'wallet': + return 'WalletOwnershipCredential' + case 'passkey': + return 'PasskeyAccountCredential' + case 'farcaster': + return 'FarcasterAccountCredential' + case 'google': + return 'GoogleAccountCredential' + case 'apple': + return 'AppleAccountCredential' + case 'twitter': + return 'TwitterAccountCredential' + case 'github': + return 'GitHubAccountCredential' + case 'discord': + return 'DiscordAccountCredential' + case 'email': + return 'EmailAccountCredential' + case 'phone': + return 'PhoneAccountCredential' + } } private createJWS(hash: Hex, challenge: string, domain?: string): string { @@ -594,6 +603,35 @@ export class VerifiableCredentialVerifier { } } +export function getOnChainProviderId(provider: AuthProvider): number { + switch (provider) { + case 'wallet': + return 0 + case 'farcaster': + return 1 + case 'google': + return 2 + case 'apple': + return 3 + case 'twitter': + return 4 + case 'github': + return 5 + case 'discord': + return 6 + case 'email': + return 7 + case 'phone': + return 8 + case 'passkey': + throw new Error('Passkey credentials are not yet supported on-chain') + default: { + const _exhaustive: never = provider + return _exhaustive + } + } +} + export function createCredentialHash(credential: VerifiableCredential): Hex { const essential = { type: credential.type, @@ -613,20 +651,8 @@ export function credentialToOnChainAttestation( issuedAt: number expiresAt: number } { - const providerMap: Record = { - wallet: 0, - farcaster: 1, - google: 2, - apple: 3, - twitter: 4, - github: 5, - discord: 6, - email: 7, - phone: 8, - } - return { - provider: providerMap[credential.credentialSubject.provider], + provider: getOnChainProviderId(credential.credentialSubject.provider), providerId: keccak256(toBytes(credential.credentialSubject.providerId)), credentialHash: createCredentialHash(credential), issuedAt: Math.floor(new Date(credential.issuanceDate).getTime() / 1000), diff --git a/packages/auth/src/sdk/client.ts b/packages/auth/src/sdk/client.ts index 160981f59..02e255b46 100644 --- a/packages/auth/src/sdk/client.ts +++ b/packages/auth/src/sdk/client.ts @@ -35,7 +35,7 @@ function generateUUID(): string { return `${hex.slice(0, 8)}-${hex.slice(8, 12)}-${hex.slice(12, 16)}-${hex.slice(16, 20)}-${hex.slice(20)}` } -function arrayBufferToBase64url(buffer: ArrayBuffer): string { +function arrayBufferToBase64url(buffer: ArrayBuffer | SharedArrayBuffer): string { const bytes = new Uint8Array(buffer) let binary = '' for (const byte of bytes) { @@ -56,6 +56,28 @@ function isPasskeyRequestOptions( return 'challenge' in options && !('user' in options) } +function isAuthenticatorAssertionResponse( + response: AuthenticatorResponse, +): response is AuthenticatorAssertionResponse { + return 'authenticatorData' in response && 'signature' in response +} + +function safeArrayBufferToBase64url( + value: + | ArrayBuffer + | SharedArrayBuffer + | ArrayBufferView + | null + | undefined, +): string | undefined { + if (!value) return undefined + const buffer = + value instanceof ArrayBuffer || value instanceof SharedArrayBuffer + ? value + : value.buffer.slice(value.byteOffset, value.byteOffset + value.byteLength) + return arrayBufferToBase64url(buffer) +} + // OAuth callback data schema const OAuthCallbackSchema = z.object({ code: z.string().optional(), @@ -622,21 +644,34 @@ export class OAuth3Client { if (!('attestationObject' in response)) { throw new Error('Invalid passkey registration response') } + const registrationResponse = response as AuthenticatorAttestationResponse & { + attestationObject: ArrayBuffer | SharedArrayBuffer + clientDataJSON: ArrayBuffer + } responsePayload = { clientDataJSON, - attestationObject: arrayBufferToBase64url(response.attestationObject), + attestationObject: safeArrayBufferToBase64url( + registrationResponse.attestationObject, + ), } } else { - if (!('authenticatorData' in response) || !('signature' in response)) { + if (!isAuthenticatorAssertionResponse(response)) { throw new Error('Invalid passkey authentication response') } + const assertionResponse = response as AuthenticatorAssertionResponse & { + authenticatorData: ArrayBuffer | SharedArrayBuffer + signature: ArrayBuffer | SharedArrayBuffer + userHandle?: ArrayBuffer | SharedArrayBuffer | null + } responsePayload = { clientDataJSON, - authenticatorData: arrayBufferToBase64url(response.authenticatorData), - signature: arrayBufferToBase64url(response.signature), - userHandle: response.userHandle - ? arrayBufferToBase64url(response.userHandle) - : undefined, + authenticatorData: safeArrayBufferToBase64url( + assertionResponse.authenticatorData, + ), + signature: safeArrayBufferToBase64url(assertionResponse.signature), + userHandle: safeArrayBufferToBase64url( + assertionResponse.userHandle, + ), } } diff --git a/packages/auth/src/types.ts b/packages/auth/src/types.ts index 52cf6e77c..f10cea210 100644 --- a/packages/auth/src/types.ts +++ b/packages/auth/src/types.ts @@ -5,9 +5,11 @@ * threshold MPC signing, and W3C Verifiable Credentials. */ -import type { JsonRecord, TEEAttestation } from '@jejunetwork/types' +import type { TEEAttestation } from '@jejunetwork/types' import type { Address, Hex } from 'viem' +export type { JsonRecord } from '@jejunetwork/types' + export const AuthProvider = { WALLET: 'wallet', PASSKEY: 'passkey', diff --git a/packages/bridge/package.json b/packages/bridge/package.json index 573a13681..3ee73a11a 100644 --- a/packages/bridge/package.json +++ b/packages/bridge/package.json @@ -12,7 +12,8 @@ "import": "./dist/index.js" }, "./tee": { - "types": "./dist/tee/index.d.ts", + "bun": "./src/tee/index.ts", + "types": "./src/tee/index.ts", "import": "./dist/tee/index.js" }, "./config/*": "./config/*.json" diff --git a/packages/cli/src/commands/login.ts b/packages/cli/src/commands/login.ts index b7b21b8b7..130e86995 100644 --- a/packages/cli/src/commands/login.ts +++ b/packages/cli/src/commands/login.ts @@ -212,6 +212,11 @@ async function authenticateWithDWS( export const loginCommand = new Command('login') .description('Authenticate with Jeju Network using your wallet') .option('-n, --network ', 'Network to authenticate with', 'testnet') + .option( + '--address
', + 'Wallet address to authenticate (required for --external mode)', + ) + .option('--signature ', 'Wallet signature from --external flow') .option( '-k, --private-key ', 'Private key (or use DEPLOYER_PRIVATE_KEY env)', @@ -256,6 +261,15 @@ export const loginCommand = new Command('login') } if (options.external) { + if (!options.address) { + logger.error('External auth requires --address.') + logger.info( + 'Example: jeju login --external --address 0xYourAddress --network localnet', + ) + return + } + + const address = options.address as Address // External signing - output message for user to sign elsewhere const nonce = bytesToHex(randomBytes(32)) const timestamp = Date.now() @@ -266,14 +280,50 @@ export const loginCommand = new Command('login') timestamp, ) - logger.info('Sign the following message with your wallet:\n') - console.log('---') - console.log(message) - console.log('---\n') + if (!options.signature) { + logger.info('Sign the following message with your wallet:\n') + console.log('---') + console.log(message) + console.log('---\n') + + logger.info('Then run:') + logger.info( + `jeju login --network ${network} --address ${address} --signature `, + ) + return + } + + // Complete external login with provided signature + const signature = options.signature + const isValid = await verifyMessage({ address, message, signature }) + if (!isValid) { + logger.error('Signature verification failed') + return + } + + const authResult = await authenticateWithDWS( + address, + signature, + message, + network, + ) + + const credentials: Credentials = { + version: 1, + network, + address, + keyType: 'external', + authToken: authResult.token, + createdAt: Date.now(), + expiresAt: authResult.expiresAt, + } - logger.info('Then run:') + saveCredentials(credentials) + logger.success(`Logged in as ${address}`) + logger.info(`Network: ${network}`) + logger.info(`Expires at: ${new Date(authResult.expiresAt).toLocaleDateString()}`) logger.info( - `jeju login --network ${network} --signature --address `, + 'Use `jeju login` again if your token expires or you change wallets.', ) return } diff --git a/packages/monitoring/package.json b/packages/monitoring/package.json index c4aca5921..fee33f5f9 100644 --- a/packages/monitoring/package.json +++ b/packages/monitoring/package.json @@ -3,6 +3,7 @@ "version": "0.1.0", "dependencies": { "@elysiajs/cors": "^1.4.0", + "@jejunetwork/auth": "workspace:*", "@jejunetwork/cache": "workspace:*", "@jejunetwork/config": "workspace:*", "@jejunetwork/types": "workspace:*", diff --git a/packages/sqlit/src/client.ts b/packages/sqlit/src/client.ts index b447eb7a7..d1d97470e 100644 --- a/packages/sqlit/src/client.ts +++ b/packages/sqlit/src/client.ts @@ -296,7 +296,12 @@ export class SQLitClient { return rows.filter(isRecordRow) } - if (isRecordRow(rows[0])) { + const firstRow = rows[0] + if (firstRow === undefined) { + return [] + } + + if (isRecordRow(firstRow)) { return rows.filter(isRecordRow) } diff --git a/packages/training/python/src/training/jeju_env.py b/packages/training/python/src/training/jeju_env.py index e052a4531..b5b3a8f4f 100644 --- a/packages/training/python/src/training/jeju_env.py +++ b/packages/training/python/src/training/jeju_env.py @@ -20,7 +20,7 @@ import os import random from contextlib import AbstractAsyncContextManager -from typing import TYPE_CHECKING, ClassVar, Optional, Protocol, TypedDict, cast +from typing import TYPE_CHECKING, ClassVar, Optional, Protocol, TypeAlias, TypedDict, cast if TYPE_CHECKING: from .tinker_client import JejuTinkerClient @@ -100,8 +100,9 @@ class _Rollout(TypedDict): finish_reason: str -class ScoredDataGroupWithInferenceLogprobs(ScoredDataGroup, total=False): - inference_logprobs: list[list[float]] +# Atropos expects a typed mapping for scored data; keep this as a local alias +# to avoid re-deriving assumptions from a third-party TypedDict inheritance pattern. +ScoredDataGroupWithInferenceLogprobs: TypeAlias = ScoredDataGroup class JejuEnvConfig(BaseEnvConfig): @@ -653,7 +654,7 @@ async def _score_with_judge(self, rollout_data: list[_Rollout]) -> ScoredDataGro ] images_list: list[list[str]] = [[] for _ in rollout_data] - scored_group: ScoredDataGroupWithInferenceLogprobs = { + scored_group = { "tokens": tokens_list, "masks": masks_list, "scores": centered_scores, @@ -666,7 +667,7 @@ async def _score_with_judge(self, rollout_data: list[_Rollout]) -> ScoredDataGro "inference_logprobs": logprobs_list, } - return scored_group + return cast(ScoredDataGroupWithInferenceLogprobs, scored_group) async def evaluate(self, *args, **kwargs): # noqa: ARG002 """Evaluate current model performance"""